diff --git a/utils.py b/utils.py index 1eb2d977..981f8a0a 100644 --- a/utils.py +++ b/utils.py @@ -162,19 +162,18 @@ def get_triplets(self, embeddings, labels): anchor_positives = list(combinations(label_indices, 2)) # All anchor-positive pairs anchor_positives = np.array(anchor_positives) - ap_distances = distance_matrix[anchor_positives[:, 0], anchor_positives[:, 1]] - for anchor_positive, ap_distance in zip(anchor_positives, ap_distances): - loss_values = ap_distance - distance_matrix[torch.LongTensor(np.array([anchor_positive[0]])), torch.LongTensor(negative_indices)] + self.margin - loss_values = loss_values.data.cpu().numpy() - hard_negative = self.negative_selection_fn(loss_values) + ap_distances = distance_matrix[anchor_positives[:, 0], anchor_positives[:, 1]] + self.margin + idxs = np.ix_(anchor_positives[:, 0], negative_indices) + loss_values = ap_distances.unsqueeze(dim=1) - distance_matrix[idxs] + loss_values = loss_values.data.cpu().numpy() + for i, loss_val in enumerate(loss_values): + hard_negative = self.negative_selection_fn(loss_val) if hard_negative is not None: hard_negative = negative_indices[hard_negative] - triplets.append([anchor_positive[0], anchor_positive[1], hard_negative]) + triplets.append([anchor_positives[i][0], anchor_positives[i][1], hard_negative]) if len(triplets) == 0: - triplets.append([anchor_positive[0], anchor_positive[1], negative_indices[0]]) - - triplets = np.array(triplets) + triplets.append([anchor_positives[-1][0], anchor_positives[-1][1], negative_indices[0]]) return torch.LongTensor(triplets)