-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy patheval-acc.py
67 lines (59 loc) · 2.25 KB
/
eval-acc.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import json
import sys
import os.path
from collections import defaultdict
import numpy as np
import torch
import utils
import config
def _read_json(path):
    """Parse the JSON file at ``path`` and return the decoded object."""
    with open(path, 'r') as handle:
        return json.load(handle)


# Validation-split question and annotation files, located via utils.path_for
q_json = _read_json(utils.path_for(val=True, question=True))
a_json = _read_json(utils.path_for(val=True, answer=True))
# Complementary question pairs; later iterated as two-element (a, b) entries
pairs = _read_json(os.path.join(config.qa_path, 'v2_mscoco_val2014_complementary_pairs.json'))

# Question ids and question texts, kept in the order they appear in the file
question_list = q_json['questions']
question_ids = [entry['question_id'] for entry in question_list]
questions = [entry['question'] for entry in question_list]

# One answer type per annotation: {'yes/no', 'other', 'number'}
answer_list = a_json['annotations']
categories = [annotation['answer_type'] for annotation in answer_list]
def _category_is(category):
    """Build a predicate accepting question ids whose answer type equals ``category``.

    ``id_to_cat`` is looked up when the predicate is called, not when it is
    built, so it only has to exist by the time the filters are used.
    """
    return lambda x: id_to_cat[x] == category


def _is_counting_question(x):
    """Return whether the question text for id ``x`` starts with 'how many', ignoring case."""
    return id_to_question[x].lower().startswith('how many')


# Named question-id filters: each value maps a question id to a bool.
# Insertion order is kept: number, yes/no, other, count, all.
accept_condition = {
    category: _category_is(category)
    for category in ('number', 'yes/no', 'other')
}
accept_condition['count'] = _is_counting_question
accept_condition['all'] = lambda x: True
# Accuracy per setting, one entry appended per log file.
# Keys are (filter name, 'single' or 'pair').
statistics = defaultdict(list)

# These lookups depend only on the question/annotation data, so they are built
# once rather than once per log file. The filters in accept_condition read
# them when called. The zip pairs ids and categories purely by position.
id_to_cat = dict(zip(question_ids, categories))
id_to_question = dict(zip(question_ids, questions))

for path in sys.argv[1:]:
    # NOTE(review): torch.load unpickles the file; only pass logs from a trusted source.
    log = torch.load(path)
    ans = log['eval']
    # Order the accuracies by ascending 'idx' so that zipping them with
    # question_ids pairs them up by position.
    by_idx = sorted(zip(ans['accuracies'], ans['idx']), key=lambda entry: entry[-1])
    id_to_acc = dict(zip(question_ids, (acc for acc, _ in by_idx)))

    for name, f in accept_condition.items():
        # Single-question accuracy: mean accuracy over every accepted question.
        single = [id_to_acc[x] for x in question_ids if f(x)]
        statistics[name, 'single'].append(np.mean(single))

        # Pair accuracy: a pair scores 1 only if both of its questions are
        # accepted by the filter and both have an accuracy of exactly 1.
        paired = [
            1 if id_to_acc[a] == id_to_acc[b] == 1 else 0
            for a, b in pairs
            if f(a) and f(b)
        ]
        statistics[name, 'pair'].append(np.mean(paired))
# Print mean and sample standard deviation across log files for every
# (filter name, pairness) setting collected in statistics.
for (name, pairness), accs in statistics.items():
    mean = np.mean(accs)
    # The sample standard deviation (ddof=1) is undefined for a single run:
    # numpy would return nan and emit RuntimeWarnings. Produce the nan
    # directly instead, so the printed line is unchanged but warning-free.
    std = np.std(accs, ddof=1) if len(accs) > 1 else float('nan')
    print('{} ({})\t: {:.2f}% +- {}'.format(name, pairness, 100 * mean, 100 * std))