# utils.py
import unicodedata

import matplotlib.pyplot as plt
import tensorflow as tf

# List of Unicode trope (Hebrew cantillation) symbols, U+0591..U+05AE
# (1425-1454), with the leading "HEBREW ACCENT" (or "HEBREW PUNCTUATION")
# prefix dropped from their names.
# Includes SOF PASUQ (chr(1475)), i.e. the verse-ending ":",
# not the meteg inside the word.
trops = [chr(x) for x in range(1425, 1455)] + [chr(1475)]
trop_names = {x: " ".join(unicodedata.name(x).split(" ")[2:]) for x in trops}
trop_names["START"] = "START"
trop_names["END"] = "END"


def plot_attention_head(in_tokens, translated_tokens, attention):
    """Plot one head's attention matrix: input tokens on x, output tokens on y."""
    # The model didn't generate `<START>` in the output. Skip it.
    translated_tokens = translated_tokens[1:]

    ax = plt.gca()
    ax.matshow(attention)

    ax.set_xticks(range(len(in_tokens)))
    ax.set_yticks(range(len(translated_tokens)))

    labels = [label.decode('utf-8') for label in in_tokens.numpy()]
    ax.set_xticklabels(labels, rotation=90)

    labels = [label.decode('utf-8') for label in translated_tokens.numpy()]
    ax.set_yticklabels(labels)


def plot_attention_weights(sentence, translated_tokens, attention_heads, tokenizers):
    """Plot every attention head for one sentence in a 2x4 grid (assumes 8 heads)."""
    # Tokenize the input sentence and look up the token strings for labeling.
    in_tokens = tf.convert_to_tensor([sentence])
    in_tokens = tokenizers.pt.tokenize(in_tokens).to_tensor()
    in_tokens = tokenizers.pt.lookup(in_tokens)[0]

    fig = plt.figure(figsize=(16, 8))
    for h, head in enumerate(attention_heads):
        ax = fig.add_subplot(2, 4, h + 1)  # one subplot per head
        plot_attention_head(in_tokens, translated_tokens, head)
        ax.set_xlabel(f'Head {h+1}')

    plt.tight_layout()
    plt.show()
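
# Illustrative usage (a sketch; `translator` and the shape of its outputs are
# assumptions here, following the TensorFlow transformer tutorial this code
# mirrors, where the model returns the translation and per-head attention):
#
#   translated_text, translated_tokens, attention_weights = translator(sentence)
#   plot_attention_weights(sentence, translated_tokens,
#                          attention_weights[0], tokenizers)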