Deep metric learning and classification with online triplet mining

PyTorch implementation of triplet networks for metric learning.

Installation

This package requires PyTorch 1.4.0 and torchvision 0.5.0.

Features

  • GPU implementation of online triplet loss, with an interface similar to the built-in PyTorch loss modules
  • Implementation of the 1-1 sampling strategy defined in [1]
  • Random semi-hard and fixed semi-hard negative sampling
  • UMAP visualization of the learned embeddings
  • Optional training of a classifier after the embeddings have been learned
  • Stratified sampling strategy for building the batches
  • Example on the MNIST dataset
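Online triplet mining with semi-hard negatives can be sketched as follows. This is a minimal, hypothetical sketch, not the repository's `OnlineTripletLoss`: the class name `SemiHardTripletLoss`, the Euclidean distance, and the default `margin=1.0` are assumptions. It illustrates only the "random semi-hard" rule, where for every anchor-positive pair in the batch a negative n is drawn uniformly at random from those satisfying d(a, p) < d(a, n) < d(a, p) + margin. It does not implement the 1-1 strategy of [1] or the fixed semi-hard variant.

```python
import torch
import torch.nn as nn


class SemiHardTripletLoss(nn.Module):
    """Hypothetical sketch: online triplet loss with random semi-hard negatives."""

    def __init__(self, margin: float = 1.0):
        super().__init__()
        self.margin = margin

    def forward(self, embeddings: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        # Pairwise Euclidean distances; clamping before sqrt keeps gradients finite at 0.
        diff = embeddings.unsqueeze(1) - embeddings.unsqueeze(0)
        dist = diff.pow(2).sum(-1).clamp_min(1e-12).sqrt()          # (B, B)

        same = labels.unsqueeze(1) == labels.unsqueeze(0)           # (B, B)
        eye = torch.eye(len(labels), dtype=torch.bool, device=labels.device)

        # Every ordered (anchor, positive) pair that shares a label.
        a_idx, p_idx = (same & ~eye).nonzero(as_tuple=True)         # (P,), (P,)
        d_ap = dist[a_idx, p_idx]                                   # (P,)
        d_an = dist[a_idx]                                          # (P, B)

        # Semi-hard negatives: different label, farther than the positive,
        # but still inside the margin.
        semi = (
            (~same[a_idx])
            & (d_an > d_ap.unsqueeze(1))
            & (d_an < d_ap.unsqueeze(1) + self.margin)
        )

        valid = semi.any(dim=1)
        if not valid.any():
            # No usable triplet in this batch: return a zero still attached to the graph.
            return embeddings.sum() * 0.0

        # Pick one semi-hard negative uniformly at random per pair:
        # random scores in [0, 1), -1 on non-candidates, then argmax.
        scores = torch.rand_like(d_an).masked_fill(~semi, -1.0)
        neg_idx = scores.argmax(dim=1, keepdim=True)                # (P, 1)
        d_neg = d_an.gather(1, neg_idx).squeeze(1)                  # (P,)

        losses = torch.relu(d_ap - d_neg + self.margin)
        return losses[valid].mean()
```

All distance and selection work is done with batched tensor operations, so the sketch runs on the GPU when `embeddings` and `labels` live there. Pairs without any semi-hard negative are simply dropped from the average.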

Code Structure

  • networks.py
    • ConvNet - base network that embeds images as vectors and predicts their labels
  • loss.py
    • OnlineTripletLoss - triplet loss computed on the embeddings of a batch
    • NegativeTripletSelector - selects the negative sample for each triplet from the batch, according to the sampling strategy
  • train.py
    • TripletTrainer - trains the network with the triplet loss and, if required, trains a classifier afterwards
  • utils.py
    • make_weights_for_balanced_classes - assigns a weight to every sample in the dataset for balanced batch sampling
    • save_embedding_umap - saves UMAP plots of the training-set and test-set embeddings
  • config.json
    • hyperparameters for training
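Per-sample weights for balanced batch sampling can be sketched as follows. This is a minimal sketch assuming inverse-class-frequency weighting; the function name `balanced_sample_weights` and its signature are hypothetical and may differ from `make_weights_for_balanced_classes` in `utils.py`.

```python
from collections import Counter
from typing import List, Sequence


def balanced_sample_weights(labels: Sequence[int]) -> List[float]:
    """Hypothetical sketch: weight every sample by the inverse frequency of its class.

    Each class then carries the same total weight (len(labels)), so a sampler
    that draws indices proportionally to these weights picks every class with
    equal probability.
    """
    counts = Counter(labels)
    total = len(labels)
    return [total / counts[label] for label in labels]


if __name__ == "__main__":
    labels = [0, 0, 0, 1]  # an imbalanced toy dataset
    print(balanced_sample_weights(labels))
```

In PyTorch, such weights can be passed to `torch.utils.data.WeightedRandomSampler(weights, num_samples=len(weights), replacement=True)`, and the sampler given to `DataLoader(..., sampler=sampler)`. Weighted random sampling balances classes only in expectation; enforcing a fixed number of samples per class in every batch would need a custom batch sampler.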

References

[1] Theoretical Guarantees of Deep Embedding Losses Under Label Noise