forked from grammarly/gector
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathm2_to_parallel.py
70 lines (58 loc) · 2.52 KB
/
m2_to_parallel.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import argparse
import os
def get_all_coder_ids(edits):
    """Return the distinct annotator ids found in *edits*, in ascending order.

    Each edit is a "|||"-separated string whose last field is the integer id
    of the annotator (coder) who produced it.
    """
    # A set comprehension de-duplicates; sorted() turns it into an ordered list.
    return sorted({int(line.split("|||")[-1]) for line in edits})
def _apply_edits(src_tokens, edits, coder, skip):
    """Return a copy of *src_tokens* with the edits of one annotator applied."""
    cor_tokens = list(src_tokens)
    # Edit spans index the *source* tokens; `offset` tracks how much the
    # corrected list has grown or shrunk so far, so later spans still line up.
    offset = 0
    for edit in edits:
        fields = edit.split("|||")
        if fields[1] in skip:
            continue  # Ignore certain edits
        if int(fields[-1]) != coder:
            continue  # Ignore other coders
        span = fields[0].split()[1:]  # Ignore "A "
        start = int(span[0])
        end = int(span[1])
        replacement = fields[2].split()
        cor_tokens[start + offset:end + offset] = replacement
        offset = offset - (end - start) + len(replacement)
    return cor_tokens


def m2_to_parallel(m2_files, ori, cor, drop_unchanged_samples, all):
    """Convert M2 annotation files into line-aligned source/corrected files.

    Every entry of an M2 file is an "S <tokens>" line followed by edit lines
    of the form "A <start> <end>|||<type>|||<correction>|||...|||<coder id>".
    For each entry (and each selected annotator) the edits are applied to the
    source tokens and one line is written to each output file.

    Args:
        m2_files: iterable of paths of the M2 files to read (UTF-8).
        ori: path of the file receiving the original sentences, or None to
            write no source file at all.
        cor: path of the file receiving the corrected sentences.
        drop_unchanged_samples: when true, skip pairs whose corrected
            sentence is identical to the original one.
        all: when truthy, emit one pair per annotator id present in the
            entry; otherwise only the edits of annotator 0 are applied.

    Both output files are created/truncated, and every file opened here is
    closed before returning, including when an error is raised. Blank
    entries (an empty input file, or extra blank lines between entries) are
    skipped rather than producing empty output lines.
    """
    # Do not apply edits with these error types
    skip = {"noop", "UNK", "Um"}

    ori_fout = None
    if ori is not None:
        ori_fout = open(ori, 'w', encoding="utf-8")
    try:
        with open(cor, 'w', encoding="utf-8") as cor_fout:
            for m2_file in m2_files:
                with open(m2_file, encoding="utf-8") as m2_fin:
                    entries = m2_fin.read().strip().split("\n\n")
                for entry in entries:
                    if not entry.strip():
                        continue  # Nothing but whitespace: not a sentence
                    # Extra blank lines between entries leave stray newlines
                    # at the edges of an entry after the "\n\n" split.
                    lines = entry.strip("\n").split("\n")
                    # ori_sent keeps the raw text after "S ", whereas the
                    # corrected sentence is re-joined with single spaces.
                    ori_sent = lines[0][2:]  # Ignore "S "
                    src_tokens = lines[0].split()[1:]  # Ignore "S "
                    edits = lines[1:]

                    coders = get_all_coder_ids(edits) if all else [0]
                    for coder in coders:
                        cor_tokens = _apply_edits(src_tokens, edits, coder, skip)
                        cor_sent = " ".join(cor_tokens)
                        if drop_unchanged_samples and ori_sent == cor_sent:
                            continue
                        if ori_fout is not None:
                            ori_fout.write(ori_sent + "\n")
                        cor_fout.write(cor_sent + "\n")
    finally:
        if ori_fout is not None:
            ori_fout.close()
if __name__ == "__main__":
    # Command-line entry point: convert every M2 file given via --data into a
    # pair of files written next to it, <name>.src (original sentences) and
    # <name>.tgt (corrected sentences).
    parser = argparse.ArgumentParser()
    parser.add_argument('--data', nargs='+', default=[])
    parser.add_argument('--erroneous_only', default=False, action='store_true', help='drop sentence without edits')
    parser.add_argument('--all_annotators', default=False, action='store_true', help='get all annotators')
    args = parser.parse_args()

    for filepath in args.data:
        # Only the extension-less path is needed to build the output names.
        path = os.path.splitext(filepath)[0]
        m2_to_parallel([filepath], path + '.src', path + '.tgt', args.erroneous_only, args.all_annotators)