-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathlosses.py
41 lines (31 loc) · 1.12 KB
/
losses.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import torch
import torch.nn.functional as F
# DCGAN loss
def loss_dcgan_dis(dis_fake, dis_real):
    """Compute the two DCGAN discriminator loss terms.

    Args:
        dis_fake: Discriminator outputs for generated samples.
        dis_real: Discriminator outputs for real samples.

    Returns:
        A ``(real_term, fake_term)`` tuple: the mean softplus of the negated
        real outputs, and the mean softplus of the fake outputs. The two
        terms are returned separately rather than summed.
    """
    real_term = F.softplus(-dis_real).mean()
    fake_term = F.softplus(dis_fake).mean()
    return real_term, fake_term
def loss_dcgan_gen(dis_fake):
    """Compute the DCGAN generator loss.

    Args:
        dis_fake: Discriminator outputs for generated samples.

    Returns:
        The mean softplus of the negated discriminator outputs.
    """
    return F.softplus(-dis_fake).mean()
# Hinge Loss
def loss_hinge_dis(dis_fake, dis_real):
    """Compute the two hinge-loss terms for the discriminator.

    Args:
        dis_fake: Discriminator outputs for generated samples.
        dis_real: Discriminator outputs for real samples.

    Returns:
        A ``(loss_real, loss_fake)`` tuple: the mean of ``relu(1 - dis_real)``
        and the mean of ``relu(1 + dis_fake)``. The terms are returned
        separately rather than summed.
    """
    real_margin = 1. - dis_real
    fake_margin = 1. + dis_fake
    return F.relu(real_margin).mean(), F.relu(fake_margin).mean()
# def loss_hinge_dis(dis_fake, dis_real): # This version returns a single loss
# loss = torch.mean(F.relu(1. - dis_real))
# loss += torch.mean(F.relu(1. + dis_fake))
# return loss
def loss_hinge_gen(dis_fake):
    """Compute the hinge generator loss.

    Args:
        dis_fake: Discriminator outputs for generated samples.

    Returns:
        The negated mean of ``dis_fake``.
    """
    mean_score = dis_fake.mean()
    return -mean_score
def loss_latent(latent_reps, pos, neg, margin=0.5, device=None):
    """Compute a margin-based hinge loss on latent-space distances.

    Each entry contributes
    ``max(margin + ||latent_reps - neg|| - ||latent_reps - pos||, 0)``,
    where the norms are L2 norms over the last dimension. The hinge values
    are then averaged over dimension 0.

    NOTE(review): as written, the loss shrinks when ``latent_reps`` is far
    from ``pos`` and close to ``neg``, which is the opposite sign convention
    from a conventional triplet margin loss. The original behavior is kept;
    confirm against the caller that this is intended.

    Args:
        latent_reps: Latent representations to score.
        pos: Reference tensor subtracted from ``latent_reps`` (the "positive").
        neg: Reference tensor subtracted from ``latent_reps`` (the "negative").
        margin: Additive margin inside the hinge. Defaults to ``0.5``.
        device: Device on which to compute the loss. When ``None``, ``cuda:0``
            is used if CUDA is available; otherwise the computation stays on
            the device of ``latent_reps`` instead of failing on CPU-only
            machines.

    Returns:
        The hinge values averaged over dimension 0, on the selected device.
    """
    if device is None:
        if torch.cuda.is_available():
            device = torch.device("cuda:0")
        else:
            device = latent_reps.device

    # Move every operand to one device so the subtractions never mix devices.
    anchor = latent_reps.to(device)
    dist_pos = torch.linalg.norm(anchor - pos.to(device), dim=-1)
    dist_neg = torch.linalg.norm(anchor - neg.to(device), dim=-1)

    # relu(x) == max(x, 0) without allocating a zero tensor on every call.
    hinge = F.relu(margin + dist_neg - dist_pos)
    return torch.mean(hinge, dim=0)
# Default loss selection: callers use these module-level aliases, so swapping
# the loss family only requires rebinding the names below.
# The hinge variants are the defaults; note that loss_hinge_dis returns a
# (loss_real, loss_fake) pair rather than a single summed loss.
generator_loss = loss_hinge_gen
discriminator_loss = loss_hinge_dis
latent_loss = loss_latent