pan20_small_predict.py (forked from janithnw/pan2020_authorship_verification)
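"""
Batch prediction script for the PAN 2020 authorship verification task.

Reads pairs.jsonl from the evaluation directory (one JSON object per line,
with an 'id' and a two-element 'pair' of texts), scores each pair with the
pickled writeprints transformer and classifier, and appends one
{'id': ..., 'value': ...} line per pair to answers.jsonl in the output
directory.
"""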
import numpy as np
import pickle
import json
import os
import argparse
import sys

from writeprints import prepare_entry

# Pickled artifacts produced at training time. The model filename's spelling
# matches the file shipped with the repository, so it is kept as-is.
TRANSFORMER_FILE = 'transformers.p'
MODEL_FILE = 'LiniearRegressionModal.p'
def process_batch(transformer, scaler, clf, ids, preprocessed_docs1, preprocessed_docs2, output_file):
    """Vectorize a batch of document pairs, score them, and append the results."""
    print('Extracting features:', len(ids), file=sys.stderr)
    # Writeprints features for each side of the pair, scaled to the
    # distribution seen at training time.
    X1 = scaler.transform(transformer.transform(preprocessed_docs1).todense())
    X2 = scaler.transform(transformer.transform(preprocessed_docs2).todense())
    # Each pair is represented by the absolute difference of its feature vectors.
    X = np.abs(X1 - X2)
    print('Predicting...', file=sys.stderr)
    # Probability of the positive class: both documents share an author.
    probs = clf.predict_proba(X)[:, 1]
    print('Writing to', output_file, file=sys.stderr)
    with open(output_file, 'a') as f:
        for i in range(len(ids)):
            d = {
                'id': ids[i],
                'value': float(probs[i])  # cast the numpy scalar to a plain float for JSON
            }
            json.dump(d, f)
            f.write('\n')
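# Record formats used by this script, reconstructed from the reads and writes
# above and below (illustrative values, not taken from the dataset):
#   pairs.jsonl line:   {"id": "<pair-id>", "pair": ["<text 1>", "<text 2>"]}
#   answers.jsonl line: {"id": "<pair-id>", "value": 0.87}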
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Prediction Script: PAN 2020 - Janith Weerasinghe')
    parser.add_argument('-i', type=str,
                        help='Evaluation dir')
    parser.add_argument('-o', type=str,
                        help='Output dir')
    args = parser.parse_args()

    # Validate the required arguments:
    if not args.i:
        raise ValueError('Eval dir path is required')
    if not args.o:
        raise ValueError('Output dir path is required')

    input_file = os.path.join(args.i, 'pairs.jsonl')
    output_file = os.path.join(args.o, 'answers.jsonl')
    print("Writing answers to:", output_file, file=sys.stderr)

    # Load the fitted feature transformer/scaler pair and the classifier.
    with open(TRANSFORMER_FILE, 'rb') as f:
        transformer, scaler = pickle.load(f)
    with open(MODEL_FILE, 'rb') as f:
        clf = pickle.load(f)

    preprocessed_docs1 = []
    preprocessed_docs2 = []
    ids = []
    batch_size = 100
    with open(input_file, 'r') as f:
        i = 0
        for l in f:
            if i % 100 == 0:
                print(i, file=sys.stderr)
            i += 1
            d = json.loads(l)
            ids.append(d['id'])
            preprocessed_docs1.append(prepare_entry(d['pair'][0]))
            preprocessed_docs2.append(prepare_entry(d['pair'][1]))
            if len(ids) >= batch_size:
                process_batch(transformer, scaler, clf, ids, preprocessed_docs1, preprocessed_docs2, output_file)
                preprocessed_docs1 = []
                preprocessed_docs2 = []
                ids = []
    # Score whatever remains after the last full batch; skip the call entirely
    # if the pair count was an exact multiple of the batch size.
    if ids:
        process_batch(transformer, scaler, clf, ids, preprocessed_docs1, preprocessed_docs2, output_file)
    print("Execution complete", file=sys.stderr)