-
Notifications
You must be signed in to change notification settings - Fork 1
/
predict.py
112 lines (88 loc) · 3.35 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import argparse
import logging
import pickle
from pathlib import Path
import torch
import transformers
from tqdm import tqdm
from load import load_data_from_file
from utils import dump_args, init_logger
logger = logging.getLogger("bert.predict")
def main(args: argparse.Namespace) -> None:
    """Predict labels for a CoNLL-style test file with a trained tagger.

    Loads the pickled label encoder and the trained model from
    ``args.model_path``, runs the model over every sentence of
    ``args.test_path`` (batch size fixed at 1), and writes a CoNLL-style
    file ``token<sep>gold<sep>pred`` per line to
    ``args.output_path / args.output_name``.

    Raises:
        AssertionError: if the number of predictions consumed for a sentence
            does not match the number the model produced for it.
    """
    device = torch.device(args.device)
    args.output_path.mkdir(exist_ok=True, parents=True)

    # NOTE(review): pickle.load / torch.load execute arbitrary code from the
    # loaded file — only safe if the model directory is trusted.
    with (args.model_path / "label_encoder.pk").open("rb") as file:
        label_encoder = pickle.load(file)

    # Batch size is hard-coded to 1: the boundary logic below
    # (mask.argmin(1), preds[0]) assumes exactly one sentence per batch.
    test_loader, _ = load_data_from_file(
        args.test_path,
        1,
        args.token_column,
        args.predict_column,
        args.lang_model_name,
        512,
        args.separator,
        args.pad_label,
        args.null_label,
        device,
        label_encoder,
        False,
    )
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        args.lang_model_name, use_fast=True
    )

    model = torch.load(args.model_path / "model.pt", map_location=args.device)
    model.fine_tune = False
    model.eval()

    list_labels = []
    logger.info("Predicting tags")
    # Inference only: disable gradient tracking so autograd buffers are not
    # built for every batch (faster, much lower memory).
    with torch.no_grad():
        for test_x, _, mask, _ in tqdm(test_loader):
            # Call the module itself (not .forward) so nn.Module hooks run.
            logits = model(test_x, mask)
            preds = torch.argmax(logits, 2)
            # First zero in the mask marks the end of the sentence; step back
            # one to drop the trailing special token.
            end = mask.argmin(1) - 1
            # Skip index 0 (leading special token, e.g. [CLS]).
            labels = label_encoder.inverse_transform(preds[0][1:end].tolist())
            list_labels.append(labels)

    in_path = args.test_path
    out_path = args.output_path / args.output_name
    with in_path.open() as in_file, out_path.open("w") as out_file:
        sentence_idx = 0
        label_idx = 0
        for line in in_file:
            if line.startswith("#"):
                # Copy comment lines through unchanged.
                out_file.write(line)
            elif line not in [" ", "\n"]:
                tokens = line.strip().split(args.separator)
                token = tokens[args.token_column]
                gold = tokens[args.predict_column]
                pred = list_labels[sentence_idx][label_idx]
                out_file.write(args.separator.join([token, gold, pred]) + "\n")
                # The model emits one prediction per subword piece, so advance
                # by however many pieces this surface token tokenizes into.
                subtokens = tokenizer.encode(token, add_special_tokens=False)
                label_idx += len(subtokens)
            else:
                # Sentence boundary: every prediction for this sentence must
                # have been consumed exactly.
                assert label_idx == len(list_labels[sentence_idx])
                out_file.write("\n")
                sentence_idx += 1
                label_idx = 0
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("test_path", type=Path)
parser.add_argument("model_path", type=Path)
parser.add_argument("token_column", type=int)
parser.add_argument("predict_column", type=int)
parser.add_argument("lang_model_name", type=str)
parser.add_argument("--batch_size", type=int, default=32)
parser.add_argument("--output_path", type=Path, default="output")
parser.add_argument("--output_name", type=str, default="predict.conllu")
parser.add_argument("--separator", type=str, default=" ")
parser.add_argument("--pad_label", type=str, default="<pad>")
parser.add_argument("--null_label", type=str, default="<X>")
parser.add_argument("--device", default="cpu")
parser.add_argument(
"--log-all",
action="store_true",
help="Enable logging of everything, including libraries like transformers",
)
args = parser.parse_args()
log_name = None if args.log_all else "bert"
init_logger(log_name=log_name)
dump_args(args)
main(args)