-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathutils.py
75 lines (65 loc) · 1.84 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
'''
Module : utils
Author: Nasibullah ([email protected])
'''
import sys
import os
import re
import random
import unicodedata
import numpy as np
import torch
class Utils:
    '''
    Generic utility function class.

    A namespace of stateless helpers (all static methods) for seeding the
    random number generators, normalizing text, converting index tensors
    back to caption strings and computing a masked negative log-likelihood.
    '''

    @staticmethod
    def set_seed(seed):
        '''
        For reproducibility.

        Seeds Python's ``random`` module, NumPy's global generator and the
        PyTorch CPU and CUDA generators with the same value.

        Args:
            seed: Integer seed passed unchanged to every generator.
        '''
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    @staticmethod
    def unicodeToAscii(s):
        '''
        Turn a Unicode string to plain ASCII,
        Thanks to https://stackoverflow.com/a/518232/2809427

        The string is NFD-normalized (base characters and combining marks
        become separate code points) and every code point whose category is
        'Mn' (nonspacing mark) is dropped. Characters without such a
        decomposition are kept as they are.

        Args:
            s: Input string.

        Returns:
            The string with combining marks removed.
        '''
        return ''.join(
            c for c in unicodedata.normalize('NFD', s)
            if unicodedata.category(c) != 'Mn'
        )

    @staticmethod
    def normalizeString(s):
        '''
        Lowercase, trim, and remove non-letter characters.

        Steps: lowercase and strip surrounding whitespace, remove combining
        marks, put a space in front of each '.', '!' or '?', then replace
        every run of characters other than ASCII letters and '.', '!', '?'
        with a single space.

        NOTE: the result is not stripped again after the substitutions, so
        it can still start or end with a single space (e.g. when the input
        ends with digits).

        Args:
            s: Input string.

        Returns:
            The normalized string.
        '''
        # Reuse the shared helper instead of repeating its body inline.
        s = Utils.unicodeToAscii(s.lower().strip())
        # Detach sentence punctuation from the preceding word.
        s = re.sub(r"([.!?])", r" \1", s)
        # Collapse anything that is neither a letter nor . ! ? into one space.
        s = re.sub(r"[^a-zA-Z.!?]+", r" ", s)
        return s

    @staticmethod
    def target_tensor_to_caption(voc, target):
        '''
        Convert target tensor to Caption.

        Each column of ``target`` is read as one sequence of vocabulary
        indices and turned into a space-separated string of words.

        Args:
            voc: Object exposing an ``index2word`` mapping from integer
                index to word.
            target: 2-D tensor of indices; column ``i`` holds caption ``i``
                (presumably shaped (max_length, batch) - confirm with the
                caller).

        Returns:
            List with one caption string per column of ``target``.
        '''
        captions = []
        for i in range(target.size(1)):
            # The original read the undefined name ``targets`` here, which
            # raised NameError on every call; the parameter is ``target``.
            words = ' '.join(voc.index2word[x.item()] for x in target[:, i])
            captions.append(words)
        return captions

    @staticmethod
    def maskNLLLoss(inp, target, mask, device):
        '''
        Masked cross-entropy loss calculation.

        Picks, for every row of ``inp``, the value at the column given by
        ``target``, takes its negative log and averages only the entries
        selected by ``mask``.

        Args:
            inp: Tensor that is 2-D after dropping a leading singleton
                dimension; presumably per-class probabilities (not logits),
                since the log is applied directly - confirm with the caller.
            target: Tensor of class indices, one per row of ``inp``.
            mask: Mask with one entry per row of ``inp`` selecting the rows
                that contribute to the loss. NOTE(review): recent PyTorch
                releases appear to expect a boolean mask here - confirm.
            device: Device the resulting loss tensor is moved to.

        Returns:
            Tuple ``(loss, nTotal)``: the mean negative log-likelihood over
            the selected rows (a tensor on ``device``) and the sum of
            ``mask`` as a Python number.
        '''
        nTotal = mask.sum()
        # Gather the value assigned to the target class of each row.
        picked = torch.gather(inp.squeeze(0), 1, target.view(-1, 1)).squeeze(1)
        crossEntropy = -torch.log(picked.float())
        loss = crossEntropy.masked_select(mask).mean()
        loss = loss.to(device)
        return loss, nTotal.item()