forked from janithnw/pan2020_authorship_verification
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathpan20_large_predict.py
89 lines (71 loc) · 2.63 KB
/
pan20_large_predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import pandas as pd
import numpy as np
import pickle
import nltk
from nltk.tokenize import TweetTokenizer
from writeprints import get_writeprints_transformer, prepare_entry
from tqdm import tqdm
import json
import re
import os
import string
import argparse
import sys
import torch
# Pickled (transformer, scaler) tuple produced at training time — unpickled in __main__.
TRANSFORMER_FILE = 'transformers.p'
# Serialized PyTorch model checkpoint loaded with torch.load in __main__.
MODEL_FILE = 'best_model.pt'
class PANDatasetIterator(torch.utils.data.IterableDataset):
    """Streams (feature-vector, pair-id) examples from a PAN-2020 jsonl file.

    Each input line is a JSON object with keys 'id' and 'pair' (a 2-element
    list of documents). Each document is featurized via ``prepare_entry`` +
    the fitted ``transformer``, scaled with ``scaler``, and the example is
    the element-wise absolute difference of the two feature vectors.
    """

    def __init__(self, f_in, transformer, scaler):
        # f_in: path to the jsonl file of pairs.
        # transformer/scaler: fitted sklearn-style objects exposing .transform().
        self.f_in = f_in
        self.transformer = transformer
        self.scaler = scaler

    def mapper(self, line):
        """Parse one jsonl line into (float32 feature vector, pair id)."""
        d = json.loads(line)
        # BUG FIX: the original referenced module-level globals `transformer`
        # and `scaler` here, silently ignoring the instance attributes set in
        # __init__; use self.* so the constructor arguments actually matter.
        x1 = self.scaler.transform(
            self.transformer.transform([prepare_entry(d['pair'][0])]).todense())
        x2 = self.scaler.transform(
            self.transformer.transform([prepare_entry(d['pair'][1])]).todense())
        x = np.abs(x1 - x2)[0, :].astype('float32')
        return x, d['id']

    def __iter__(self):
        # NOTE(review): the file handle is never explicitly closed; it relies
        # on interpreter cleanup when the iterator is garbage-collected.
        f_itr = open(self.f_in, 'r')
        return map(self.mapper, f_itr)
if __name__ == "__main__":
    # CLI: -i <eval dir containing pairs.jsonl>, -o <output dir for answers.jsonl>
    parser = argparse.ArgumentParser(description='Prediction Script: PAN 2020 - Janith Weerasinghe')
    parser.add_argument('-i', type=str,
                        help='Evaluation dir')  # BUG FIX: typo "Evaluaiton"
    parser.add_argument('-o', type=str,
                        help='Output dir')
    args = parser.parse_args()

    # Validate required paths up front so failures are immediate and clear.
    if not args.i:
        raise ValueError('Eval dir path is required')
    if not args.o:
        raise ValueError('Output dir path is required')

    input_file = os.path.join(args.i, 'pairs.jsonl')
    output_file = os.path.join(args.o, 'answers.jsonl')
    print("Writing answers to:", output_file, file=sys.stderr)

    # Load the fitted feature transformer + scaler saved at training time.
    with open(TRANSFORMER_FILE, 'rb') as f:
        transformer, scaler = pickle.load(f)

    device = torch.device('cpu')
    # BUG FIX: map_location lets a GPU-trained checkpoint load on a CPU-only
    # host (the original torch.load would raise on such machines).
    with open(MODEL_FILE, 'rb') as f:
        best_model = torch.load(f, map_location=device)
    # BUG FIX: switch to inference mode so dropout/batch-norm layers (if any)
    # behave deterministically during prediction.
    best_model.eval()

    ds = PANDatasetIterator(input_file, transformer, scaler)
    test_loader = torch.utils.data.DataLoader(dataset=ds, batch_size=1000)

    c = 0
    # NOTE(review): append mode ('a') preserved from the original — rerunning
    # the script appends duplicate answers; use 'w' if a fresh file per run is
    # intended. Context manager added so the handle closes even on error.
    with open(output_file, 'a') as fout, torch.no_grad():
        for x, ids in test_loader:
            x = x.to(device)
            outputs = best_model(x)
            probs = outputs.numpy()[:, 0].astype(float)
            for i in range(len(ids)):
                # float() guarantees a native Python float for json.dump
                # regardless of numpy version / scalar type.
                json.dump({'id': ids[i], 'value': float(probs[i])}, fout)
                fout.write('\n')
            c += len(ids)
            print(c, file=sys.stderr)
    print('Written to', output_file, flush=True, file=sys.stderr)
    print("Execution complete", file=sys.stderr)