# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional


def post_process_response(token_ids: List[int], tokenizer):
    '''
    Post-process a decoded sequence: truncate it at the first <eos>
    (the tokenizer's `sep_token_id`) and convert the ids back to tokens.
    '''
    eos_pos = len(token_ids)
    for i, tok_id in enumerate(token_ids):
        if tok_id == tokenizer.sep_token_id:
            eos_pos = i
            break
    token_ids = token_ids[:eos_pos]
    tokens = tokenizer.convert_ids_to_tokens(token_ids)
    tokens = tokenizer.merge_subword(tokens)
    return token_ids, tokens
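
# For illustration (not in the original source): with `tokenizer.sep_token_id == 2`,
# a decoded sequence [10, 11, 2, 12] is truncated to [10, 11] before being mapped
# back to tokens and merged into whole words.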


def get_in_turn_repetition(pred: List[str], is_cn: bool = False):
    '''
    Return True if the prediction contains a repeated tri-gram
    (in-turn repetition).
    '''
    if len(pred) == 0:
        # An empty prediction is treated as repetitive so it gets penalized.
        return True
    if isinstance(pred[0], str):
        pred = [tok.lower() for tok in pred]
        if is_cn:
            # For Chinese, join the tokens so tri-grams are built over characters.
            pred = "".join(pred)
    tri_grams = set()
    for i in range(len(pred) - 2):
        tri_gram = tuple(pred[i:i + 3])
        if tri_gram in tri_grams:
            return True
        tri_grams.add(tri_gram)
    return False
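
# For illustration (not in the original source): the check flags any tri-gram that
# occurs twice within a single prediction, e.g.
#   get_in_turn_repetition(["i", "like", "it", "i", "like", "it"])  -> True
#   get_in_turn_repetition(["i", "like", "it", "very", "much"])     -> False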


def select_response(ids,
                    scores,
                    tokenizer,
                    max_dec_len: Optional[int] = None,
                    num_return_sequences: int = 1,
                    keep_space: bool = True):
    '''
    Select the response with the highest score from each group of
    `num_return_sequences` candidates. `ids` and `scores` are the Tensors
    returned by generation.
    '''
    ids = ids.numpy().tolist()
    scores = scores.numpy()
    if len(ids) != len(scores) or (len(ids) % num_return_sequences) != 0:
        raise ValueError(
            "The length of `ids` ({}) must equal the length of `scores` and be "
            "a multiple of `num_return_sequences` ({}).".format(
                len(ids), num_return_sequences))

    group = []
    tmp = []
    for pred, score in zip(ids, scores):
        pred_token_ids, pred_tokens = post_process_response(pred, tokenizer)
        num_token = len(pred_token_ids)
        if keep_space:
            response = " ".join(pred_tokens)
        else:
            response = "".join(pred_tokens)

        # Check repetition both over the joined tokens and over the raw token ids.
        in_turn_repetition = get_in_turn_repetition(
            pred_tokens, True) or get_in_turn_repetition(pred_token_ids)
        # Penalize candidates that hit the length limit (not ending) or repeat.
        if max_dec_len is not None and num_token >= max_dec_len:
            score -= 1e3
        elif in_turn_repetition:
            score -= 1e3

        tmp.append([response, score])
        if len(tmp) == num_return_sequences:
            group.append(tmp)
            tmp = []

    # Keep the highest-scoring candidate from each group.
    results = []
    for preds in group:
        preds = sorted(preds, key=lambda x: -x[1])
        results.append(preds[0][0])
    return results
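

# Minimal usage sketch (an addition for illustration, not part of the original
# module): it exercises `select_response` with a stub tokenizer and hand-made
# tensors so the ranking logic can be checked without loading a real PaddleNLP
# model. The stub only mimics the tokenizer attributes used above
# (`sep_token_id`, `convert_ids_to_tokens`, `merge_subword`); all names below
# are illustrative assumptions.
if __name__ == "__main__":
    import paddle

    class _StubTokenizer:
        sep_token_id = 2

        def convert_ids_to_tokens(self, token_ids):
            return ["w{}".format(i) for i in token_ids]

        def merge_subword(self, tokens):
            return tokens

    # Two candidates for one input; both end at <sep> (id 2), padded with 0.
    ids = paddle.to_tensor([[5, 6, 7, 2, 0], [8, 9, 2, 0, 0]])
    scores = paddle.to_tensor([-1.2, -0.4])
    print(select_response(ids, scores, _StubTokenizer(),
                          max_dec_len=5, num_return_sequences=2))
    # Prints ['w8 w9']: the higher-scoring candidate, truncated at <sep>.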