"""
Batched beam search.

The model must implement two functions:
(1) decode_step
(2) prepare_incremental_input
"""
import torch

from data import END, UNK



class Hypothesis(object):
    def __init__(self, state_dict, seq, score):
        # state_dict: hidden states from the last step (seq[-1] has not been fed to the model yet)
        # seq: token sequence generated so far
        # score: accumulated score so far
        self.state_dict = state_dict
        self.seq = seq
        self.score = score

    def is_completed(self):
        return self.seq[-1] == END

    def __len__(self):
        return len(self.seq)


class Beam(object):
    """One beam per test instance."""

    def __init__(self, beam_size, min_time_step, max_time_step, hypotheses, device):
        # hypotheses: the collection of currently alive hypotheses
        self.beam_size = beam_size
        self.min_time_step = min_time_step
        self.max_time_step = max_time_step
        self.completed_hypotheses = []
        self.steps = 0
        self.hypotheses = hypotheses
        self.device = device

    def merge_score(self, prev_hyp, step):
        # step is a (token, score) pair
        token, score = step
        if token == UNK:
            # never extend a hypothesis with UNK
            return float('-inf')
        return prev_hyp.score + score

    def update(self, new_states, last_steps):
        # last_steps: list (#num_hypotheses) of list (#beam_size) of (token, score)
        candidates = []
        for prev_hyp_idx, steps in enumerate(last_steps):
            for step in steps:
                token = step[0]
                score = self.merge_score(self.hypotheses[prev_hyp_idx], step)
                candidates.append((prev_hyp_idx, token, score))
        candidates.sort(key=lambda x: x[-1], reverse=True)
        # keep only as many live hypotheses as there are free beam slots
        live_hyp_num = self.beam_size - len(self.completed_hypotheses)
        candidates = candidates[:live_hyp_num]
        # candidates: list of triples (prev_hyp_idx, token, score)
        new_hyps = []
        _prev_hyp_idx = torch.tensor([x[0] for x in candidates], device=self.device)
        _split_state = dict()  # key => tuple of per-hypothesis state slices
        for k, v in new_states.items():
            # tensors with >= 3 dims are batched on dim 1 (e.g. layers x batch x hidden),
            # everything else on dim 0
            split_dim = 1 if v.dim() >= 3 else 0
            _split_state[k] = v.index_select(split_dim, _prev_hyp_idx).split(1, dim=split_dim)
        for idx, (prev_hyp_idx, token, score) in enumerate(candidates):
            state = {k: v[idx] for k, v in _split_state.items()}
            seq = self.hypotheses[prev_hyp_idx].seq + [token]
            new_hyps.append(Hypothesis(state, seq, score))
        self.hypotheses = []
        for hyp in new_hyps:
            if hyp.is_completed():
                # len(hyp) - 2 excludes the start token and END
                if len(hyp) - 2 >= self.min_time_step:
                    self.completed_hypotheses.append(hyp)
            else:
                self.hypotheses.append(hyp)
        self.steps += 1
        # self.print_everything()  # debugging hook

    def completed(self):
        return len(self.completed_hypotheses) >= self.beam_size or self.steps >= self.max_time_step

    def get_k_best(self, k, alpha):
        if len(self.completed_hypotheses) == 0:
            # nothing finished in time: fall back to the still-alive hypotheses
            self.completed_hypotheses = self.hypotheses
        # rank by length-normalized score; a larger alpha penalizes long sequences more
        self.completed_hypotheses.sort(key=lambda x: x.score / ((1 + len(x.seq)) ** alpha), reverse=True)
        return self.completed_hypotheses[:k]

    def print_everything(self):
        print('alive:')
        for x in self.hypotheses:
            print(x.seq)
        print('completed:')
        for x in self.completed_hypotheses:
            print(x.seq)


def search_by_batch(model, beams, mem_dict):
    def ready_to_submit(hypotheses):
        # batch the last generated token of every alive hypothesis
        inp = model.prepare_incremental_input([hyp.seq[-1:] for hyp in hypotheses])
        # concatenate the per-hypothesis states back into batched tensors
        concat_hyps = dict()
        for hyp in hypotheses:
            for k, v in hyp.state_dict.items():
                concat_hyps[k] = concat_hyps.get(k, []) + [v]
        for k, v in concat_hyps.items():
            if v[0].dim() >= 3:
                concat_hyps[k] = torch.cat(v, 1)
            else:
                concat_hyps[k] = torch.cat(v, 0)
        return concat_hyps, inp

    while True:
        # collect the alive hypotheses of every unfinished beam,
        # remembering which beam each one came from
        hypotheses = []
        indices = []
        offset = -1
        for idx, beam in enumerate(beams):
            if not beam.completed():
                for hyp in beam.hypotheses:
                    hypotheses.append(hyp)
                    indices.append(idx)
                    offset = len(hyp.seq) - 1
        if not hypotheses:
            break
        indices = torch.tensor(indices, device=beams[0].device)
        state_dict, inp = ready_to_submit(hypotheses)
        # slice the encoder memory so it lines up with the surviving hypotheses
        cur_mem_dict = dict()
        for k, v in mem_dict.items():
            if isinstance(v, list):
                cur_mem_dict[k] = [v[i] for i in indices]
            else:
                cur_mem_dict[k] = v.index_select(1, indices)
        state_dict, results = model.decode_step(inp, state_dict, cur_mem_dict, offset, beams[0].beam_size)
        # hand each beam its own slice of the new states and predictions
        _len_each_beam = [len(beam.hypotheses) for beam in beams if not beam.completed()]
        _state_dict_each_beam = [dict() for _ in _len_each_beam]
        for k, v in state_dict.items():
            split_dim = 1 if v.dim() >= 3 else 0
            for i, x in enumerate(v.split(_len_each_beam, dim=split_dim)):
                _state_dict_each_beam[i][k] = x
        _pos = 0
        _idx = 0
        for beam in beams:
            if not beam.completed():
                _len = len(beam.hypotheses)
                beam.update(_state_dict_each_beam[_idx], results[_pos:_pos + _len])
                _pos += _len
                _idx += 1
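

# Usage sketch (illustrative only; `model`, `BOS`, `init_states`, `mem_dict`,
# `batch_size`, and `device` are hypothetical stand-ins for project objects --
# only the Hypothesis/Beam/search_by_batch API above is real):
#
#   beams = [Beam(beam_size=5, min_time_step=1, max_time_step=100,
#                 hypotheses=[Hypothesis(init_states[i], [BOS], 0.)],
#                 device=device)
#            for i in range(batch_size)]
#   search_by_batch(model, beams, mem_dict)
#   best_seqs = [beam.get_k_best(1, alpha=0.6)[0].seq for beam in beams]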