-
Notifications
You must be signed in to change notification settings - Fork 1
/
loss.py
110 lines (94 loc) · 3.93 KB
/
loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import random
from itertools import combinations
import torch
import torch.nn as nn
import torch.nn.functional as F
class OnlineTripleLoss(nn.Module):
    """Triplet margin loss with online (in-batch) negative mining.

    Triplets are mined from each batch by ``NegativeTripletSelector`` and the
    per-triplet loss ``relu(d(a, p) - d(a, n) + margin)`` is averaged, where
    ``d`` is ``F.pairwise_distance`` (2-norm, default ``eps``).

    Args:
        margin: Margin added to the anchor-positive distance; also passed to
            the selector for choosing negatives.
        sampling_strategy: Negative-mining strategy forwarded to
            ``NegativeTripletSelector``.
    """

    def __init__(self, margin, sampling_strategy="random_sh"):
        super(OnlineTripleLoss, self).__init__()
        self.margin = margin
        self.triplet_selector = NegativeTripletSelector(margin, sampling_strategy)

    def forward(self, embeddings, labels):
        """Compute the mean triplet loss for one batch.

        Args:
            embeddings: Batch of embeddings whose rows are selected by the
                mined triplet indices (presumably shape ``(batch, dim)`` --
                TODO confirm against callers).
            labels: Per-row labels; rows with equal labels are positives.

        Returns:
            ``(loss, num_triplets)``. When no triplet is mined, ``loss`` is a
            zero that stays attached to ``embeddings``' autograd graph rather
            than the NaN produced by ``mean()`` over an empty tensor.
        """
        anchors, positives, negatives = self.triplet_selector.get_triplets(
            embeddings, labels
        )
        if len(anchors) == 0:
            # mean() of an empty tensor is NaN; keep the graph connected
            # instead so backward() and loss logging still behave.
            return embeddings.sum() * 0.0, 0
        anchor_embeddings = embeddings[anchors, :]
        ap_dists = F.pairwise_distance(anchor_embeddings, embeddings[positives, :])
        an_dists = F.pairwise_distance(anchor_embeddings, embeddings[negatives, :])
        loss = F.relu(ap_dists - an_dists + self.margin)
        return loss.mean(), len(anchors)
class NegativeTripletSelector:
    """Mine (anchor, positive, negative) index triplets from a labelled batch.

    For every unordered pair of same-label rows (anchor, positive), one
    negative is chosen from the differently-labelled rows according to
    ``sampling_strategy``:

    * ``"random_sh"`` -- via ``random_semi_hard_sampling``: a random negative
      among those with ``d(a, n) < d(a, p) + margin``.
    * ``"fixed_sh"`` -- via ``fixed_semi_hard_sampling``: the negative with
      the largest margin violation (the closest negative).

    NOTE(review): both strategies accept any margin-violating negative,
    including hard ones with ``d(a, n) < d(a, p)``, not only strictly
    semi-hard ones.

    Anchor/positive pairs with no margin-violating negative produce no
    triplet.

    Raises:
        ValueError: if ``sampling_strategy`` is not a supported name.
    """

    _STRATEGIES = ("random_sh", "fixed_sh")

    def __init__(self, margin, sampling_strategy="random_sh"):
        super(NegativeTripletSelector, self).__init__()
        if sampling_strategy not in self._STRATEGIES:
            # Previously an unknown name silently mined zero triplets.
            raise ValueError(
                f"Unknown sampling_strategy {sampling_strategy!r}; "
                f"expected one of {self._STRATEGIES}"
            )
        self.margin = margin
        self.sampling_strategy = sampling_strategy

    def get_triplets(self, embeddings, labels):
        """Return ``[anchors, positives, negatives]`` index lists for the batch.

        Labels with fewer than two rows are skipped (no anchor/positive pair).
        """
        # Mining only selects indices, so the N x N distance matrix does not
        # need an autograd graph.
        distance_matrix = pdist(embeddings.detach(), eps=0)
        triplets_indices = [[], [], []]
        for label in torch.unique(labels):
            label_mask = labels == label
            label_indices = torch.where(label_mask)[0]
            if label_indices.shape[0] < 2:
                continue
            negative_indices = torch.where(torch.logical_not(label_mask))[0]
            anchors, positives, negatives = self.get_one_one_triplets(
                label_indices, negative_indices, distance_matrix
            )
            triplets_indices[0].extend(anchors)
            triplets_indices[1].extend(positives)
            triplets_indices[2].extend(negatives)
        return triplets_indices

    def get_one_one_triplets(self, pos_labels, negative_indices, dist_mat):
        """Pick at most one negative for each anchor/positive pair.

        Args:
            pos_labels: Row indices sharing one label (despite the name, these
                are indices, not labels).
            negative_indices: Row indices with a different label.
            dist_mat: Pairwise distance matrix indexed by row index.

        Returns:
            ``[anchors, positives, negatives]`` index lists.
        """
        triplets_indices = [[], [], []]
        for anchor_idx, pos_idx in combinations(pos_labels, 2):
            ap_dist = dist_mat[anchor_idx, pos_idx]
            an_dists = dist_mat[anchor_idx, negative_indices]
            if self.sampling_strategy == "random_sh":
                neg_list_idx = random_semi_hard_sampling(
                    ap_dist, an_dists, self.margin
                )
            elif self.sampling_strategy == "fixed_sh":
                neg_list_idx = fixed_semi_hard_sampling(
                    ap_dist, an_dists, self.margin
                )
            else:
                raise ValueError(
                    f"Unknown sampling_strategy {self.sampling_strategy!r}"
                )
            if neg_list_idx is None:
                continue
            triplets_indices[0].append(anchor_idx)
            triplets_indices[1].append(pos_idx)
            triplets_indices[2].append(negative_indices[neg_list_idx])
        return triplets_indices
def random_semi_hard_sampling(ap_dist, an_dists, margin):
    """Pick a random negative that violates the triplet margin.

    A candidate ``k`` qualifies when ``ap_dist + margin - an_dists[k] > 0``,
    i.e. ``an_dists[k] < ap_dist + margin``. This includes hard negatives
    (closer than the positive), not only strictly semi-hard ones.

    Args:
        ap_dist: Anchor-positive distance.
        an_dists: 1-D tensor of anchor-to-candidate-negative distances.
        margin: Triplet margin.

    Returns:
        Position of the chosen negative within ``an_dists`` as a Python int
        (matching ``fixed_semi_hard_sampling``), or ``None`` when no candidate
        qualifies. The choice uses Python's global ``random`` generator.
    """
    loss = ap_dist + margin - an_dists
    candidates = torch.where(loss > 0)[0].tolist()
    if not candidates:
        return None
    return random.choice(candidates)
def fixed_semi_hard_sampling(ap_dist, an_dists, margin):
    """Pick the negative that violates the triplet margin the most.

    Each candidate's violation is ``ap_dist + margin - an_dists[k]``. The
    candidate with the largest violation (equivalently, the smallest
    anchor-negative distance) is returned, provided that violation is
    positive.

    NOTE(review): despite the name this is hardest-negative mining; it may
    return a negative that is closer than the positive.

    Args:
        ap_dist: Anchor-positive distance.
        an_dists: 1-D tensor of anchor-to-candidate-negative distances.
        margin: Triplet margin.

    Returns:
        Position of the chosen negative within ``an_dists`` as a Python int,
        or ``None`` when no candidate has a positive violation.
    """
    violations = (ap_dist + margin) - an_dists
    if not bool((violations > 0).any()):
        return None
    return torch.argmax(violations).item()
def pdist(vectors, eps):
    """Return the matrix of pairwise distances between the rows of ``vectors``.

    Entry ``[i, j]`` equals ``F.pairwise_distance(vectors[i], vectors[j],
    eps=eps)``, i.e. the 2-norm of ``vectors[i] - vectors[j] + eps``.

    Args:
        vectors: 2-D tensor of shape ``(N, D)``.
        eps: Constant added to each difference before the norm, forwarded
            to ``F.pairwise_distance``.

    Returns:
        Tensor of shape ``(N, N)``; ``(0, 0)`` for an empty batch.
    """
    # Broadcasting (N, 1, D) against (1, N, D) computes every pair in one
    # call, replacing a per-row Python loop that also crashed on N == 0
    # (torch.cat of an empty list raises).
    return F.pairwise_distance(vectors.unsqueeze(1), vectors.unsqueeze(0), eps=eps)