-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathEMLLoss.py
37 lines (32 loc) · 1.18 KB
/
EMLLoss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import torch
import torch.nn as nn
import torch.nn.functional as F
class Loss(nn.Module):
    """Composite saliency loss: KL divergence + (1 - CC) + NSS gap.

    ``forward`` walks the batch sample by sample, sums the three per-sample
    terms, and divides the total by the batch size.  Judging by the argument
    names, ``fix`` holds fixation maps and ``smap`` ground-truth saliency maps
    -- NOTE(review): confirm against the callers.
    """

    def __init__(self, eps=1e-6):
        """Create the loss.

        Args:
            eps: Small constant used inside the KL term to avoid division by
                zero and ``log(0)``.  The default reproduces the previously
                hard-coded value.
        """
        super().__init__()
        self.eps = eps

    def KL_loss(self, input, target):
        """Return ``sum(t * log(t / (p + eps) + eps))`` for normalized maps.

        Both maps are first divided by their own sum so each sums to 1.  The
        maps are presumably non-negative saliency maps (TODO confirm); a map
        whose sum is zero makes the normalization produce NaN/inf.
        """
        input = input / input.sum()
        target = target / target.sum()
        return (target * torch.log(target / (input + self.eps) + self.eps)).sum()

    def CC_loss(self, input, target):
        """Return ``1 - CC``, where CC is the correlation of the two maps.

        Both maps are standardized (mean subtracted, divided by ``std``) and
        then CC = sum(a*b) / sqrt(sum(a*a) * sum(b*b)).  A constant map has
        zero std and yields NaN.
        """
        input = (input - input.mean()) / input.std()
        target = (target - target.mean()) / target.std()
        cc = (input * target).sum() / torch.sqrt(
            (input * input).sum() * (target * target).sum()
        )
        return 1 - cc

    def NSS_loss(self, input, target):
        """Return the NSS gap between ``target`` itself and ``input``.

        NSS(m, t) is the ``t``-weighted mean of the standardized map ``m``
        (for a binary fixation map ``t``, the mean at fixated pixels).  The
        loss is ``NSS(target, target) - NSS(input, target)``.  The first term
        does not depend on ``input``, so it only offsets the value, and
        minimizing the loss maximizes the NSS of ``input``.  A ``target``
        that sums to zero (no fixations) gives NaN/inf.
        """
        # Standardized fixation map; used only for the constant offset term.
        ref = (target - target.mean()) / target.std()
        input = (input - input.mean()) / input.std()
        return (ref * target - input * target).sum() / target.sum()

    def forward(self, input, fix, smap):
        """Return the batch-averaged sum of the KL, CC and NSS terms.

        Args:
            input: Predicted maps, iterated along dim 0 (presumably
                ``(batch, H, W)`` -- TODO confirm).
            fix: Per-sample maps used for the NSS term; same batch size.
            smap: Per-sample maps used for the KL and CC terms; same batch
                size.

        Raises:
            ValueError: if ``fix`` or ``smap`` do not have the same number of
                samples as ``input``.  Previously ``zip`` silently truncated
                to the shortest batch while still dividing by
                ``input.size(0)``, which mis-scaled the loss.
        """
        batch_size = input.size(0)
        if len(fix) != batch_size or len(smap) != batch_size:
            raise ValueError(
                f"batch size mismatch: input={batch_size}, "
                f"fix={len(fix)}, smap={len(smap)}"
            )
        kl = 0
        cc = 0
        nss = 0
        for p, f, s in zip(input, fix, smap):
            kl += self.KL_loss(p, s)
            cc += self.CC_loss(p, s)
            nss += self.NSS_loss(p, f)
        return (kl + cc + nss) / batch_size