-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtest.py
112 lines (92 loc) · 4.61 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import os
import time
import string
import argparse
import torch
import torch.backends.cudnn as cudnn
import torch.utils.data
import torch.nn.functional as F
import numpy as np
from nltk.metrics.distance import edit_distance
from utils import CTCLabelConverter, AttnLabelConverter, Averager
from dataset import hierarchical_dataset, AlignCollate
from model import Model
def _prune_at_eos(text, eos='[s]'):
    """Cut *text* just before the first *eos* marker.

    Returns ``(pruned_text, cut_index)``. When the marker is absent the text is
    returned unchanged and ``cut_index`` is ``None``, so slicing a parallel
    sequence with ``[:cut_index]`` keeps every element. (Using ``str.find``'s
    ``-1`` directly as a slice bound would silently drop the last element.)
    """
    cut_index = text.find(eos)
    if cut_index == -1:
        return text, None
    return text[:cut_index], cut_index


def _normalized_edit_similarity(pred, gt):
    """Return the ICDAR2019 ``1 - N.E.D.`` score for one prediction.

    The edit distance (``nltk.metrics.distance.edit_distance``) is normalised
    by the length of the longer of the two strings. If either string is empty
    the score is 0, as in the original metric code.
    """
    if len(gt) == 0 or len(pred) == 0:
        return 0
    return 1 - edit_distance(pred, gt) / max(len(gt), len(pred))


def _confidence_score(step_max_probs):
    """Return the product of per-step max probabilities, or 0 if there are none.

    The tensor can be empty, e.g. when an attention prediction is pruned at an
    end-of-sentence token ([s]) in its first position.
    """
    if step_max_probs.numel() == 0:
        return 0
    return step_max_probs.cumprod(dim=0)[-1]


def validation(model, criterion, evaluation_loader, converter, opt, device):
    """Run the model over ``evaluation_loader`` and compute loss and recognition metrics.

    Args:
        model: recognizer, called as ``model(image, text_for_pred)`` for CTC
            heads and ``model(image, text_for_pred, is_train=False)`` otherwise.
        criterion: loss; called with CTC-style arguments (log-probs in
            (steps, batch, classes) order, targets, input lengths, target
            lengths) for CTC, and with flattened logits/targets otherwise.
        evaluation_loader: iterable of ``(image_tensors, labels)`` batches.
        converter: label converter providing ``encode`` and, depending on the
            head/decoder, ``decode``, ``decode_greedy`` or ``decode_beamsearch``.
        opt: options; reads ``batch_max_length``, ``Prediction`` and, for CTC,
            ``decode`` (``'greedy'`` or ``'beamsearch'``).
        device: device the images and dummy decoder inputs are moved to.

    Returns:
        tuple ``(valid_loss, accuracy, norm_ED, preds_str, confidence_score_list,
        labels, infer_time, length_of_data)``:
          - ``valid_loss``: ``Averager.val()`` over the per-batch costs;
          - ``accuracy``: percentage of exact string matches;
          - ``norm_ED``: mean ICDAR2019 ``1 - N.E.D.`` score;
          - ``infer_time``: summed wall-clock seconds of the forward passes;
          - ``length_of_data``: number of samples seen.
        ``preds_str``, ``confidence_score_list`` and ``labels`` describe only
        the LAST batch.

    Raises:
        ValueError: if ``opt.decode`` is not a supported CTC decoding mode, or
            if the loader yields no samples.
    """
    n_correct = 0
    norm_ED = 0
    length_of_data = 0
    infer_time = 0
    valid_loss_avg = Averager()

    for image_tensors, labels in evaluation_loader:
        batch_size = image_tensors.size(0)
        length_of_data += batch_size
        image = image_tensors.to(device)
        # Dummy decoder inputs sized for the maximum prediction length.
        length_for_pred = torch.IntTensor([opt.batch_max_length] * batch_size).to(device)
        text_for_pred = torch.zeros(batch_size, opt.batch_max_length + 1, dtype=torch.long, device=device)

        text_for_loss, length_for_loss = converter.encode(labels, batch_max_length=opt.batch_max_length)

        # NOTE(review): on an asynchronous device (e.g. CUDA) this wall-clock
        # measurement may not cover all queued work; confirm if exact timing matters.
        start_time = time.time()
        if 'CTC' in opt.Prediction:
            preds = model(image, text_for_pred)
            forward_time = time.time() - start_time

            # One CTC input length per sample: the number of output steps.
            preds_size = torch.IntTensor([preds.size(1)] * batch_size)
            # permute 'preds' from (batch, steps, classes) to the (steps, batch, classes) CTC layout
            cost = criterion(preds.log_softmax(2).permute(1, 0, 2), text_for_loss, preds_size, length_for_loss)

            if opt.decode == 'greedy':
                # Select max probability per step (greedy decoding), then map indices to characters.
                _, preds_index = preds.max(2)
                preds_index = preds_index.view(-1)
                preds_str = converter.decode_greedy(preds_index.data, preds_size.data)
            elif opt.decode == 'beamsearch':
                preds_str = converter.decode_beamsearch(preds, beamWidth=2)
            else:
                # Without this, preds_str would be unbound and fail later with an opaque error.
                raise ValueError(f"Unsupported CTC decode mode: {opt.decode!r}")
        else:
            preds = model(image, text_for_pred, is_train=False)
            forward_time = time.time() - start_time

            # Align prediction steps with the targets, which drop the leading [GO] symbol.
            preds = preds[:, :text_for_loss.shape[1] - 1, :]
            target = text_for_loss[:, 1:]  # without [GO] Symbol
            cost = criterion(preds.contiguous().view(-1, preds.shape[-1]), target.contiguous().view(-1))

            # Select max probability per step (greedy decoding), then map indices to characters.
            _, preds_index = preds.max(2)
            preds_str = converter.decode(preds_index, length_for_pred)
            # Re-decode ground truth from the encoded targets, presumably so gt and
            # pred share the same token format (both are pruned at '[s]' below).
            labels = converter.decode(text_for_loss[:, 1:], length_for_loss)

        infer_time += forward_time
        valid_loss_avg.add(cost)

        # Calculate accuracy & confidence score.
        preds_prob = F.softmax(preds, dim=2)
        preds_max_prob, _ = preds_prob.max(dim=2)
        confidence_score_list = []
        for gt, pred, pred_max_prob in zip(labels, preds_str, preds_max_prob):
            if 'Attn' in opt.Prediction:
                # Prune everything from the end-of-sentence token ([s]) onwards.
                gt, _ = _prune_at_eos(gt)
                pred, pred_EOS = _prune_at_eos(pred)
                pred_max_prob = pred_max_prob[:pred_EOS]

            if pred == gt:
                n_correct += 1

            # ICDAR2019 Normalized Edit Distance. (The older ICDAR2017 DOST metric,
            # https://rrc.cvc.uab.es/?ch=7&com=tasks, normalised by len(gt) only.)
            norm_ED += _normalized_edit_similarity(pred, gt)

            # Confidence score = product of the per-step max probabilities.
            confidence_score_list.append(_confidence_score(pred_max_prob))

    if length_of_data == 0:
        raise ValueError("evaluation_loader yielded no samples")

    accuracy = n_correct / float(length_of_data) * 100
    norm_ED = norm_ED / float(length_of_data)  # ICDAR2019 Normalized Edit Distance

    return valid_loss_avg.val(), accuracy, norm_ED, preds_str, confidence_score_list, labels, infer_time, length_of_data