-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathget_accuracy.py
57 lines (48 loc) · 1.91 KB
/
get_accuracy.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import datasets
import pandas as pd
def evaluate_prompt(prompt, expected_answer):
    """Score one model response against the expected multiple-choice letter.

    The response is searched for every occurrence of ``Correct answer:``
    followed by a newline, an option letter (A-D, either case) and a closing
    parenthesis.

    Args:
        prompt: Text of the model response to inspect.
        expected_answer: Expected option letter. It is compared
            case-insensitively and surrounding whitespace is ignored;
            non-string values are converted with ``str`` first.

    Returns:
        1 if at least one answer is found, every answer found names the same
        letter, and that letter equals ``expected_answer``; otherwise 0.
    """
    import re

    pattern = re.compile(r'Correct answer:\n[ABCDabcd]\)')
    # Every match ends with "<letter>)", so the letter is the second-to-last
    # character. Collecting the letters in a set handles both failure modes
    # at once: no answer (empty set) and contradictory answers (size > 1).
    letters = {match[-2].upper() for match in pattern.findall(prompt)}
    if len(letters) != 1:
        return 0
    (letter,) = letters
    # str() keeps a non-string expected value (e.g. a NaN spreadsheet cell)
    # from raising; it simply never matches an option letter.
    return int(letter == str(expected_answer).strip().upper())
def split_responses(lines):
    """Split the lines of a results file into one text per numbered prompt.

    A line that is exactly ``prompt: <k>`` plus a newline, where ``<k>`` is
    the next expected index (counting from 0), starts a new response; the
    lines up to the next such header form that response's text. Lines before
    the ``prompt: 0`` header are discarded, and a header-like line carrying
    any other number is kept as ordinary text.

    Args:
        lines: Iterable of lines, each still carrying its trailing newline
            (for example an open text file).

    Returns:
        List of response texts in header order; empty when no ``prompt: 0``
        header is present.
    """
    responses = []
    chunk = []
    next_idx = 0
    for line in lines:
        if line == f'prompt: {next_idx}\n':
            # Text gathered before the very first header is preamble, not a
            # response, so it is dropped instead of stored.
            if next_idx != 0:
                responses.append("".join(chunk))
            next_idx += 1
            chunk = []
        else:
            chunk.append(line)
    # Flush the final response, but only if at least one header was seen.
    if next_idx != 0:
        responses.append("".join(chunk))
    return responses


if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    # required=True: without it a missing flag yields dataset=None, which
    # makes read_excel load every sheet and fail later with an unrelated
    # KeyError.
    parser.add_argument('--dataset', type=str, required=True,
                        choices=["EasyDataset", "HardDataset", "ComprehensiveDataset"])
    args = parser.parse_args()
    dataset = args.dataset

    test_set = pd.read_excel('chatbot_datasets/Data_OpenSource.xlsx', sheet_name=dataset)
    # Only the first character of each cell is used as the expected letter.
    # str()/strip()/[:1] tolerate padded, empty or non-string cells.
    correct_answer = [str(response).strip()[:1] for response in test_set['Correct Answer']]

    with open(f"{dataset}.txt", 'r') as input_file:
        responses = split_responses(input_file)

    if len(responses) > len(correct_answer):
        raise SystemExit(
            f"{dataset}.txt contains {len(responses)} responses but the sheet "
            f"only has {len(correct_answer)} answers"
        )

    # The i-th response is scored against the i-th expected answer. Missing
    # responses are not scored but still count in the denominator below.
    valid_cnt = sum(
        evaluate_prompt(text, expected)
        for text, expected in zip(responses, correct_answer)
    )

    print(f"{'# of correct responses ('+str(len(correct_answer))+' in total)':23}")
    total_question_num = len(correct_answer)
    # Guard against an empty sheet instead of dividing by zero.
    accuracy = valid_cnt / total_question_num if total_question_num else 0.0
    # NOTE(review): ':3f' is a minimum width of 3 with the default six
    # decimals; '.3f' may have been intended - confirm before changing output.
    print(f"{dataset:20} {valid_cnt} correct in {total_question_num}, accuracy: {accuracy:3f}")