hard_triplet_loss.py

import torch
import torch.nn as nn
import torch.nn.functional as F


class HardTripletLoss(nn.Module):
    """Hard / Hardest Triplet Loss.

    PyTorch implementation of https://omoindrot.github.io/triplet-loss.
    For each anchor, either the hardest positive and hardest negative in the batch
    are selected to form a triplet (batch-hard), or all valid triplets are averaged
    (batch-all).
    """

    def __init__(self, margin=0.1, hardest=False, squared=False):
        """
        Args:
            margin: margin for the triplet loss.
            hardest: If True, only the hardest triplet per anchor contributes to the
                loss (batch-hard). If False, the loss is averaged over all valid
                triplets with a positive loss (batch-all).
            squared: If True, pairwise squared Euclidean distances are used.
                If False, pairwise Euclidean distances are used.
        """
super(HardTripletLoss, self).__init__()
self.margin = margin
self.hardest = hardest
self.squared = squared

    def forward(self, embeddings, labels):
        """
        Args:
            embeddings: tensor of shape (batch_size, embed_dim)
            labels: labels of the batch, of size (batch_size,)

        Returns:
            triplet_loss: scalar tensor containing the triplet loss
        """
        pairwise_dist = _pairwise_distance(embeddings, squared=self.squared)

        if self.hardest:
            # Get the hardest positive pairs: for each anchor, the farthest
            # embedding that shares its label.
            mask_anchor_positive = _get_anchor_positive_triplet_mask(labels).float()
            valid_positive_dist = pairwise_dist * mask_anchor_positive
            hardest_positive_dist, _ = torch.max(valid_positive_dist, dim=1, keepdim=True)

            # Get the hardest negative pairs: for each anchor, the closest
            # embedding with a different label. Adding the row-wise maximum
            # distance to the invalid (same-label) entries ensures they are
            # never selected by the min.
            mask_anchor_negative = _get_anchor_negative_triplet_mask(labels).float()
            max_anchor_negative_dist, _ = torch.max(pairwise_dist, dim=1, keepdim=True)
            anchor_negative_dist = pairwise_dist + max_anchor_negative_dist * (
                1.0 - mask_anchor_negative)
            hardest_negative_dist, _ = torch.min(anchor_negative_dist, dim=1, keepdim=True)

            # Combine the biggest d(a, p) and the smallest d(a, n) into the final triplet loss
            triplet_loss = F.relu(hardest_positive_dist - hardest_negative_dist + self.margin)
            triplet_loss = torch.mean(triplet_loss)
else:
anc_pos_dist = pairwise_dist.unsqueeze(dim=2)
anc_neg_dist = pairwise_dist.unsqueeze(dim=1)
# Compute a 3D tensor of size (batch_size, batch_size, batch_size)
# triplet_loss[i, j, k] will contain the triplet loss of anc=i, pos=j, neg=k
# Uses broadcasting where the 1st argument has shape (batch_size, batch_size, 1)
# and the 2nd (batch_size, 1, batch_size)
loss = anc_pos_dist - anc_neg_dist + self.margin
mask = _get_triplet_mask(labels).float()
triplet_loss = loss * mask
# Remove negative losses (i.e. the easy triplets)
triplet_loss = F.relu(triplet_loss)
# Count number of hard triplets (where triplet_loss > 0)
hard_triplets = torch.gt(triplet_loss, 1e-16).float()
num_hard_triplets = torch.sum(hard_triplets)
triplet_loss = torch.sum(triplet_loss) / (num_hard_triplets + 1e-16)
return triplet_loss


def _pairwise_distance(x, squared=False, eps=1e-16):
    # Compute the 2D matrix of distances between all the embeddings using
    # ||x_i - x_j||^2 = ||x_i||^2 - 2 <x_i, x_j> + ||x_j||^2.
    cor_mat = torch.matmul(x, x.t())  # Gram matrix of dot products
    norm_mat = cor_mat.diag()         # squared norms ||x_i||^2
    distances = norm_mat.unsqueeze(1) - 2 * cor_mat + norm_mat.unsqueeze(0)
    # Clamp small negative values caused by floating point errors.
    distances = F.relu(distances)

    if not squared:
        # Add eps where the distance is exactly 0 so that sqrt has a finite
        # gradient there, then zero those entries out again.
        mask = torch.eq(distances, 0.0).float()
        distances = distances + mask * eps
        distances = torch.sqrt(distances)
        distances = distances * (1.0 - mask)

    return distances
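
# Illustrative sanity check (not part of the original file): assuming a PyTorch
# version that provides torch.cdist, the unsquared output of _pairwise_distance
# should agree with torch.cdist up to numerical tolerance, e.g.
#
#     x = torch.randn(8, 16)
#     assert torch.allclose(_pairwise_distance(x), torch.cdist(x, x), atol=1e-4)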


def _get_anchor_positive_triplet_mask(labels):
    # Return a 2D mask where mask[a, p] is True iff a and p are distinct and have the same label.
    indices_not_equal = ~torch.eye(labels.shape[0], dtype=torch.bool, device=labels.device)
    # Check if labels[i] == labels[j]
    labels_equal = labels.unsqueeze(0) == labels.unsqueeze(1)
    mask = indices_not_equal & labels_equal
    return mask


def _get_anchor_negative_triplet_mask(labels):
    # Return a 2D mask where mask[a, n] is True iff a and n have distinct labels.
    # Check if labels[i] != labels[k]
    labels_equal = labels.unsqueeze(0) == labels.unsqueeze(1)
    mask = ~labels_equal
    return mask


def _get_triplet_mask(labels):
    """Return a 3D mask where mask[a, p, n] is True iff the triplet (a, p, n) is valid.

    A triplet (i, j, k) is valid if:
        - i, j, k are distinct indices
        - labels[i] == labels[j] and labels[i] != labels[k]
    """
    # Check that i, j and k are pairwise distinct
    indices_not_same = ~torch.eye(labels.shape[0], dtype=torch.bool, device=labels.device)
    i_not_equal_j = indices_not_same.unsqueeze(2)
    i_not_equal_k = indices_not_same.unsqueeze(1)
    j_not_equal_k = indices_not_same.unsqueeze(0)
    distinct_indices = i_not_equal_j & i_not_equal_k & j_not_equal_k

    # Check that labels[i] == labels[j] and labels[i] != labels[k]
    label_equal = labels.unsqueeze(0) == labels.unsqueeze(1)
    i_equal_j = label_equal.unsqueeze(2)
    i_equal_k = label_equal.unsqueeze(1)
    valid_labels = i_equal_j & ~i_equal_k

    # Combine the two masks
    mask = distinct_indices & valid_labels
    return mask
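

if __name__ == "__main__":
    # Minimal usage sketch (illustrative, not part of the original module):
    # random embeddings stand in for the output of an encoder network and
    # random integer class ids stand in for the labels.
    torch.manual_seed(0)

    embeddings = torch.randn(32, 128)    # (batch_size, embed_dim)
    labels = torch.randint(0, 8, (32,))  # (batch_size,)

    batch_all_loss = HardTripletLoss(margin=0.1, hardest=False)(embeddings, labels)
    batch_hard_loss = HardTripletLoss(margin=0.1, hardest=True)(embeddings, labels)

    print("batch-all triplet loss :", batch_all_loss.item())
    print("batch-hard triplet loss:", batch_hard_loss.item())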