Code accompanying the NeurIPS 2021 paper **Training for the Future: A Simple Gradient Interpolation Loss to Generalize Along Time**

Anshul Nasery\*, Soumyadeep Thakur\*, Vihari Piratla, Abir De, Sunita Sarawagi

This repository contains the training and inference code for our method, as well as codebases for the different baselines. The code and instructions for running each model can be found in `src/`. The processed datasets should be downloaded to the `data/` directory from this link.
- Install `torch>=1.4` and the corresponding `torchvision`. Also install `tqdm`, `numpy`, and `sklearn` (a possible combined install command is sketched after this list).
- Install the POT library following this link.
- Install `pkg-resources==0.0.0`, `six==1.12.0`, and `Pillow==8.1.1`.
- Download the datasets into the `data/` directory from this link.
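The Python dependencies above can usually be installed in one shot with `pip`. The command below is a suggestion, not one taken from the repository; note that `sklearn` is published on PyPI as `scikit-learn`, and `pkg-resources==0.0.0` is a system-specific artifact that normally should not be installed from PyPI:

```bash
# Suggested install command (not repository-provided); package names are the
# usual PyPI names. Pick a torch/torchvision pair that matches your CUDA setup.
pip install "torch>=1.4" torchvision tqdm numpy scikit-learn \
    POT six==1.12.0 Pillow==8.1.1
```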
The file `src/<MODEL>/main.py` is usually the entry point for starting a training and inference job for each `<MODEL>`. The standard way to run this file is

```
python3 main.py --data <DS> --model <MODEL> --seed <SEED>
```

although there are minor differences between models, as described in `src/<MODEL>/README.md`. The results of each run are written to `src/<MODEL>/log_<MODEL>_<DS>`.
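For instance, a hypothetical invocation of our method might look as follows; the concrete values passed to `--data` and `--model` here are assumptions, so check `src/GI/README.md` and `src/GI/config_GI.py` for the actual supported choices:

```bash
cd src/GI
# 'house' and 'GI' are assumed placeholder values for <DS> and <MODEL>
python3 main.py --data house --model GI --seed 0
```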
The directory `src/` has four sub-folders, one for our method and three for the baselines:

- `GI/`
  - `main.py` - Entry point to the code
  - `trainer_GI.py` - Contains the training algorithm implementation (see the sketch after this list)
  - `config_GI.py` - Contains the hyperparameter configurations for the different datasets and models
  - `preprocess.py` - Can be used to generate the processed datasets from the raw files
- `CIDA/`
  - `main.py` - Entry point to the code; contains the dataloader and the training algorithm
  - `<DS>_models.py` - Contains the model definition and hyperparameters for dataset `<DS>`
- `CDOT/`
  - `ot_main.py` - Entry point to the code; contains the training algorithm implementation and the hyperparameter configurations for the different datasets and models
  - `transport.py`, `regularized_OT.py` - Contain implementations of the OT and CDOT algorithms
- `adagraph/`
  - `main_all_source.py` - Entry point to the code
  - `configs/` - Contains hyperparameters for the various datasets
  - `dataloaders/` - Contains dataloaders for the various datasets
  - `models/` - Contains model definitions
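To give a flavour of the method, here is a minimal, hypothetical sketch of a gradient-interpolation-style loss for a time-conditioned regression model. It is not the implementation from `trainer_GI.py`; the `model(x, t)` interface, the hyperparameters `delta_max` and `lambda_gi`, and the squared-error losses are all illustrative assumptions. See `trainer_GI.py` and `config_GI.py` for the real algorithm and settings.

```python
# Minimal sketch (NOT the repository implementation) of a gradient-
# interpolation-style loss: the prediction at a randomly shifted time t + delta
# is encouraged to match the first-order Taylor expansion of the prediction at t.
import torch
import torch.nn.functional as F

def gi_loss_sketch(model, x, y, t, delta_max=0.5, lambda_gi=1.0):
    """model(x, t) -> one scalar prediction per sample; x: [B, D], y and t: [B]."""
    t = t.clone().requires_grad_(True)
    pred_t = model(x, t).squeeze(-1)          # prediction at the observed time
    task_loss = F.mse_loss(pred_t, y)

    # Per-sample d(pred)/dt; the sum() trick works because sample i's
    # prediction depends only on t[i]. create_graph=True keeps the penalty
    # differentiable so it can be optimized through.
    dpred_dt = torch.autograd.grad(pred_t.sum(), t, create_graph=True)[0]

    delta = torch.rand_like(t) * delta_max    # random forward shift in time
    pred_shift = model(x, t + delta).squeeze(-1)
    taylor = pred_t + delta * dpred_dt        # first-order extrapolation along time

    gi_penalty = F.mse_loss(pred_shift, taylor)
    return task_loss + lambda_gi * gi_penalty
```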
If you find the paper or the code helpful in your research, consider citing us:

```bibtex
@inproceedings{
nasery2021training,
title={Training for the Future: A Simple Gradient Interpolation Loss to Generalize Along Time},
author={Anshul Nasery and Soumyadeep Thakur and Vihari Piratla and Abir De and Sunita Sarawagi},
booktitle={Thirty-Fifth Conference on Neural Information Processing Systems},
year={2021},
url={https://openreview.net/forum?id=U7SBcmRf65}
}
```