-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathevaluation.py
42 lines (32 loc) · 1.04 KB
/
evaluation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import argparse
from pprint import pprint
import os
from typing import List
from rouge import Rouge
from seq2seq.utils import calculate_rouge
from util import data_io
# Command-line interface: both options are optional and default to files
# under the current user's home directory.
#
# The defaults are built with os.path.expanduser instead of
# os.environ["HOME"]: indexing the environment raises KeyError at import time
# when HOME is not set (e.g. on Windows or in a stripped-down service
# environment), even if the caller passes both paths explicitly.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--pred_file",
    default=os.path.expanduser(
        "~/gunther/data/transformer_trained/test_rare_epoch_20.pred"
    ),
    type=str,
    help="path to the file with predicted lines",
)
parser.add_argument(
    "--target_file",
    default=os.path.expanduser(
        "~/gunther/Response-Generation-Baselines/processed_output/test_rare.tgt"
    ),
    type=str,
    help="path to the file with target (reference) lines",
)
def calc_rouge_scores(pred: List[str], tgt: List[str]):
    """Score predictions against targets with two ROUGE implementations.

    Args:
        pred: predicted lines.
        tgt: target lines, passed alongside ``pred`` to both scorers.

    Returns:
        A dict with two entries:

        * ``"f1-scores"``: for every metric returned by
          ``Rouge().get_scores(pred, tgt, avg=True)``, the value stored
          under that metric's ``"f"`` key.
        * ``"huggingface-rouge"``: whatever ``calculate_rouge(pred, tgt)``
          returns (imported from ``seq2seq.utils``).
    """
    averaged = Rouge().get_scores(pred, tgt, avg=True)

    # Keep only the "f" component of each metric; other components are dropped.
    f_by_metric = {}
    for metric, components in averaged.items():
        for component, value in components.items():
            if component == "f":
                f_by_metric[metric] = value

    return {
        "f1-scores": f_by_metric,
        "huggingface-rouge": calculate_rouge(pred, tgt),
    }
if __name__ == "__main__":
    # Script entry point: read both files, score them, pretty-print the result.
    cli_args = parser.parse_args()

    # Materialise the line iterables into lists before scoring.
    predictions = list(data_io.read_lines(cli_args.pred_file))
    targets = list(data_io.read_lines(cli_args.target_file))

    report = calc_rouge_scores(predictions, targets)
    pprint(report)