aggregate_scores_mb.py
#!/usr/bin/env python3
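"""Aggregate coreference scores across multiple seeds with the coval toolkit.

Given a gold key file and a base directory holding one sub-directory per
random seed (each containing a .conll system file), the script scores every
system file and writes the per-seed scores plus their mean/std aggregation
to scores.pkl and scores.txt in the base directory.
"""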
import sys
from pathlib import Path

import pandas as pd
from tabulate import tabulate

from coval.conll import util, reader
from coval.eval import evaluator


def score_all():
    allmetrics = [('mentions', evaluator.mentions), ('muc', evaluator.muc),
                  ('bcub', evaluator.b_cubed), ('ceafe', evaluator.ceafe),
                  ('lea', evaluator.lea)]

    key_file = sys.argv[1]

    # Evaluation options are passed as plain command-line flags.
    NP_only = 'NP_only' in sys.argv
    remove_nested = 'remove_nested' in sys.argv
    keep_singletons = ('remove_singletons' not in sys.argv
                       and 'remove_singleton' not in sys.argv)
    min_span = ('min_span' in sys.argv
                or 'min_spans' in sys.argv
                or 'min' in sys.argv)

    # The key file needs gold parse annotation; create a parsed copy if missing.
    has_gold_parse = util.check_gold_parse_annotation(key_file)
    if not has_gold_parse:
        util.parse_key_file(key_file)
        key_file = key_file + ".parsed"

    # Select the requested metrics; default to all of them.
    if 'all' in sys.argv:
        metrics = allmetrics
    else:
        metrics = [(name, metric) for name, metric in allmetrics
                   if name in sys.argv]
        if not metrics:
            metrics = allmetrics

    # Score all system files in a folder hierarchy of the form
    # <base dir>/<seed 0>/sys.conll, <base dir>/<seed 1>/sys.conll, ...
    sys_base_dir = Path(sys.argv[2])
    assert sys_base_dir.exists()

    all_scores = []
    for seed_dir in sys_base_dir.iterdir():
        try:
            seed = int(seed_dir.name)
        except ValueError:
            print(f"{seed_dir} is not a seed directory")
            continue
        for sys_file in seed_dir.iterdir():
            if sys_file.suffix == ".conll":
                scores = evaluate(key_file, sys_file.absolute(), metrics,
                                  NP_only, remove_nested, keep_singletons,
                                  min_span)
                scores["seed"] = seed
                scores = scores.set_index("seed", append=True)
                all_scores.append(scores)

    # Collect per-seed scores into one frame and aggregate mean and std.
    all_scores = pd.concat(all_scores)
    all_scores = all_scores.unstack("measure").dropna(axis="columns")
    scores_aggregated = all_scores.describe().loc[["mean", "std"]]

    all_scores.to_pickle(sys_base_dir / "scores.pkl")
    with (sys_base_dir / "scores.txt").open("w") as f:
        f.write(f"""
INDIVIDUAL SCORES
-----------------
{all_scores.to_csv()}
{tabulate(all_scores, headers="keys")}
AGGREGATED SCORES
-----------------
{scores_aggregated.to_csv()}
{tabulate(scores_aggregated, headers="keys")}
""")


def evaluate(key_file, sys_file, metrics, NP_only, remove_nested,
             keep_singletons, min_span):
    doc_coref_infos = reader.get_coref_infos(key_file, sys_file, NP_only,
                                             remove_nested, keep_singletons,
                                             min_span)
    conll = 0
    conll_subparts_num = 0
    scores = {}
    for name, metric in metrics:
        recall, precision, f1 = evaluator.evaluate_documents(doc_coref_infos,
                                                             metric,
                                                             beta=1)
        # The CoNLL score is the average of the MUC, B-cubed and CEAF-e F1s.
        if name in ["muc", "bcub", "ceafe"]:
            conll += f1
            conll_subparts_num += 1
        scores[name] = [recall, precision, f1]
    if conll_subparts_num == 3:
        conll = (conll / 3) * 100
        scores["conll"] = [None, None, conll]

    # One column per metric, rows R / P / F1, index named "measure".
    scores = pd.DataFrame(scores, index=["R", "P", "F1"])
    scores.index.name = "measure"
    return scores


if __name__ == '__main__':
    score_all()
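
# Example invocation (paths are illustrative; the flags and metric names are
# the ones parsed from sys.argv above):
#
#   python aggregate_scores_mb.py key.conll runs/ muc bcub ceafe
#
# where runs/ contains per-seed sub-directories such as runs/0/sys.conll,
# runs/1/sys.conll, ... Passing "all" (or no metric names) scores every metric.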