-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathpredict.py
90 lines (74 loc) · 2.52 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
import matplotlib.pyplot as plt
from data import Dictionary, Corpus, PAD_INDEX
from mst import mst
def _save_heatmap(matrix, path):
    """Render `matrix` as a heatmap on a fixed [0, 1] colour scale and save it to `path`."""
    fig, ax = plt.subplots()
    try:
        im = ax.imshow(matrix, vmin=0, vmax=1)
        fig.colorbar(im)
        fig.savefig(path)
    finally:
        # Close explicitly: plt.cla()/plt.clf() only clear the *current*
        # figure and leave every created figure registered with pyplot.
        plt.close(fig)


def plot(S_arc, heads):
    """Save heatmaps of the gold and predicted arc matrices of one sentence.

    Two files are written (the ``img`` directory is created if missing):

    * ``img/gold.pdf`` -- 0/1 matrix ``G`` with ``G[heads[d], d] = 1``, so
      rows index heads and columns index dependents.
    * ``img/a.pdf`` -- softmax of ``S_arc`` over dim 0 (each column sums
      to 1), i.e. a distribution over heads for every dependent.

    Args:
        S_arc: arc scores with a leading batch dimension of size 1 (it is
            squeezed on dim 0); presumably shape (1, n, n) -- TODO confirm.
        heads: gold head indices whose dim 1 is the sentence length n;
            presumably shape (1, n) -- TODO confirm.
    """
    import os

    os.makedirs('img', exist_ok=True)

    # Gold adjacency matrix: one 1 per column, in the row of that token's head.
    n = heads.size(1)
    gold = np.zeros((n, n))
    head_idx = heads.squeeze().data.numpy()
    gold[head_idx, np.arange(n)] = 1.
    _save_heatmap(gold, 'img/gold.pdf')

    # Predicted adjacency: normalise scores over candidate heads (dim 0).
    probs = F.softmax(S_arc.squeeze(0), dim=0)
    _save_heatmap(probs.data.numpy(), 'img/a.pdf')
def predict(model, words, tags):
    """Predict heads and labels for a single sentence.

    Args:
        model: parser called as ``model(words, tags)`` that returns a pair
            ``(S_arc, S_lab)`` of score tensors with a leading batch
            dimension; only batch element 0 is decoded.
        words: word indices, either a plain list (wrapped here into a batch
            of one) or an already-built tensor/Variable.
        tags: tag indices, given in the same form as ``words``.

    Returns:
        ``(heads, labels)``: the heads chosen by ``mst`` on the arc scores and
        a numpy array holding the arg-max label for each token given its
        predicted head.

    Raises:
        TypeError: if ``words`` and ``tags`` are not of the same type.
    """
    if type(words) is not type(tags):
        raise TypeError('words and tags must have the same type, got {} and {}'
                        .format(type(words).__name__, type(tags).__name__))
    if isinstance(words, list):
        # Convert the lists into a batch of one for the PyTorch model.
        words = Variable(torch.LongTensor([words]))
        tags = Variable(torch.LongTensor([tags]))

    was_training = model.training
    model.eval()  # Disable dropout so the prediction is deterministic.
    try:
        S_arc, S_lab = model(words, tags)
    finally:
        # Restore the caller's mode: calling predict() during training must
        # not silently leave dropout switched off afterwards.
        model.train(was_training)

    # Decode the first (only) sentence of the batch.
    return predict_batch(S_arc[0], S_lab[0], tags)
def predict_batch(S_arc, S_lab, tags):
    """Decode heads and labels from already-computed score matrices.

    Despite the name, this takes no batch dimension: ``S_arc`` is handed
    directly to ``mst``, and the same decoding is applied to ``S_arc[0]`` /
    ``S_lab[0]`` when predicting a single sentence.

    Args:
        S_arc: arc scores passed to ``mst`` after ``.data.numpy()``.
        S_lab: label scores, gathered along dim 1 at each token's predicted
            head; presumably (n_labels, n, n) with ``S_lab[l, h, d]`` scoring
            label ``l`` for the arc h -> d -- TODO confirm.
        tags: unused; kept for interface compatibility.

    Returns:
        ``(heads, labels)``: whatever ``mst`` returns (assumed to be one head
        index per token -- confirm against the mst module) and a numpy array
        with the arg-max label index per token.
    """
    # Highest-scoring spanning tree over the arc scores.
    heads = mst(S_arc.data.numpy())

    # Broadcast the head indices across every label plane, then pull out the
    # score each label gets for each token's chosen head.
    n_labels = S_lab.size(0)
    head_index = Variable(torch.LongTensor(heads).unsqueeze(0).expand(n_labels, -1))
    head_scores = torch.gather(S_lab, 1, head_index.unsqueeze(1)).squeeze(1)

    # Best label per token = arg-max over the label dimension.
    labels = head_scores.max(dim=0)[1].data.numpy()
    return heads, labels
if __name__ == '__main__':
    # Sanity check: plot the arc scores of the first training sentence and
    # compare the decoded tree against its gold heads and labels.
    data_path = '../../stanford-ptb'
    vocab_path = 'vocab/train'
    model_path = 'models/model.pt'

    corpus = Corpus(data_path=data_path, vocab_path=vocab_path)
    # NOTE(review): torch.load unpickles a whole model object, which can run
    # arbitrary code -- only load checkpoints from trusted sources.
    model = torch.load(model_path)
    # Disable dropout before scoring so the plotted scores are the same
    # deterministic ones that predict() decodes.
    model.eval()

    batches = corpus.train.batches(1)
    words, tags, heads, labels = next(batches)
    S_arc, S_lab = model(words, tags)
    plot(S_arc, heads)

    # Decode the *same* sentence (predict() accepts tensors as well as
    # lists), so the prediction printed below is comparable with the gold
    # heads/labels printed next to it.
    heads_pred, labels_pred = predict(model, words, tags)
    print(heads_pred, '\n', heads[0].data.numpy())
    print(labels_pred, '\n', labels[0].data.numpy())