-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathevaluate.py
49 lines (44 loc) · 2.01 KB
/
evaluate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
# -*- coding: utf-8 -*-
import torch
import numpy as np
from loader import load_data
"""
模型效果测试
"""
class Evaluator:
    """Measure a classifier's accuracy on held-out data, once per epoch.

    The evaluation data is loaded once at construction time. Each call to
    :meth:`eval` runs the model over it without gradients, counts correct
    and incorrect predictions, logs a summary and returns the accuracy.
    """

    def __init__(self, config, model, logger):
        """Store collaborators and load the evaluation data.

        Args:
            config: Dict-like configuration. ``"valid_data_path"`` is used for
                evaluation when present; otherwise ``"train_data_path"`` is
                used as a fallback. The whole mapping is also passed to
                ``load_data``.
            model: Callable as ``model(input_ids)`` returning per-sample
                scores; must provide ``.eval()`` (e.g. a ``torch.nn.Module``).
            logger: Logger-like object with ``info`` and ``warning`` methods.
        """
        self.config = config
        self.model = model
        self.logger = logger
        # Evaluate on the validation split: measuring accuracy on the training
        # split says nothing about generalization. Only fall back to the
        # training data when no validation path is configured, and say so.
        valid_path = config.get("valid_data_path")
        if valid_path is None:
            logger.warning(
                "config has no 'valid_data_path'; evaluating on training data"
            )
            valid_path = config["train_data_path"]
        self.valid_data = load_data(valid_path, config, shuffle=False)
        # Running tallies for the current evaluation pass.
        self.stats_dict = {"correct": 0, "wrong": 0}

    def eval(self, epoch):
        """Evaluate the model on the validation data.

        Puts the model into eval mode (it is not switched back afterwards),
        resets the tallies, accumulates predictions batch by batch and logs
        the summary.

        Args:
            epoch: Epoch number; used only in the log message.

        Returns:
            Accuracy as a float in [0, 1] (0.0 if there was no data).
        """
        self.logger.info("开始测试第%d轮模型效果:" % epoch)
        self.model.eval()
        self.stats_dict = {"correct": 0, "wrong": 0}  # discard the previous pass
        use_cuda = torch.cuda.is_available()  # loop-invariant
        with torch.no_grad():
            for batch_data in self.valid_data:
                if use_cuda:
                    batch_data = [d.cuda() for d in batch_data]
                # Assumes each batch is exactly (input_ids, labels); adjust
                # here for multi-input / multi-output models.
                input_ids, labels = batch_data
                # Predict with the current parameters; labels are not passed,
                # so the model returns scores rather than a loss.
                pred_results = self.model(input_ids)
                self.write_stats(labels, pred_results)
        return self.show_stats()

    def write_stats(self, labels, pred_results):
        """Add one batch's correct/wrong counts to ``self.stats_dict``.

        Each sample's prediction is the argmax over its flattened scores.
        Labels are converted to integers by truncation, like ``int()``.
        Presumably ``pred_results`` is (batch, num_classes) and ``labels`` is
        (batch,) or (batch, 1) class indices — confirm against the model and
        loader.

        Raises:
            ValueError: if labels and predictions disagree on batch size, or
                a sample has more than one label.
        """
        batch_size = len(labels)
        if batch_size != len(pred_results):
            raise ValueError(
                "labels and predictions differ in batch size: %d vs %d"
                % (batch_size, len(pred_results))
            )
        if batch_size == 0:
            return  # nothing to count; also avoids reshaping an empty tensor
        # reshape(batch, -1).argmax(dim=1) == per-sample argmax over the
        # flattened scores, i.e. torch.argmax(pred_label) for each row.
        pred_classes = pred_results.reshape(batch_size, -1).argmax(dim=1)
        true_classes = labels.reshape(-1).long()
        if true_classes.numel() != batch_size:
            raise ValueError("expected exactly one label per sample")
        correct = int((pred_classes == true_classes).sum())
        self.stats_dict["correct"] += correct
        self.stats_dict["wrong"] += batch_size - correct

    def show_stats(self):
        """Log the accumulated tallies and return the accuracy.

        Returns:
            ``correct / (correct + wrong)``, or 0.0 when nothing was evaluated
            (instead of raising ZeroDivisionError).
        """
        correct = self.stats_dict["correct"]
        wrong = self.stats_dict["wrong"]
        total = correct + wrong
        accuracy = correct / total if total else 0.0
        self.logger.info("预测集合条目总量:%d" % total)
        self.logger.info("预测正确条目:%d,预测错误条目:%d" % (correct, wrong))
        self.logger.info("预测准确率:%f" % accuracy)
        self.logger.info("--------------------")
        return accuracy