shahtalebi/SAND-mask

SAND-mask Repository

This repo is the code release for "SAND-mask: An Enhanced Gradient Masking Strategy for the Discovery of Invariances in Domain Generalization".

Forked from DomainBed

This project is mainly developed on top of the DomainBed repository, a PyTorch suite containing benchmark datasets and algorithms for domain generalization, as introduced in "In Search of Lost Domain Generalization".

Published results

[Figures: Agnostic table, Oracle table, Spirals table]

Available algorithms

The currently available algorithms are:


  • Learning Explanations that are Hard to Vary (AND-Mask, Parascandolo et al., 2020)
  • SAND-mask: An Enhanced Gradient Masking Strategy for the Discovery of Invariances in Domain Generalization (SAND-Mask)
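For intuition, the gradient-masking idea behind AND-Mask can be sketched in a few lines. This is a minimal illustration in plain Python, not the code from this repository: the function name `and_mask`, the list-of-lists gradient layout, and the threshold name `tau` are assumptions made for the example. A gradient component is kept only when the signs of the per-environment gradients agree strongly enough. SAND-mask is presented in its paper as an enhancement of this kind of masking; refer to the paper for its exact rule.

```python
def sign(x):
    """Return -1, 0, or 1 according to the sign of x."""
    return (x > 0) - (x < 0)


def and_mask(env_grads, tau):
    """Average per-environment gradients, zeroing components whose signs disagree.

    env_grads: one gradient vector (list of floats) per training environment.
    tau: agreement threshold in [0, 1]; 1.0 requires every environment to agree.
    (Both names are illustrative, not taken from the repository.)
    """
    n_envs = len(env_grads)
    masked = []
    for components in zip(*env_grads):
        # |mean of signs| is 1.0 when all environments agree, 0.0 when they cancel.
        agreement = abs(sum(sign(g) for g in components)) / n_envs
        mean_grad = sum(components) / n_envs
        masked.append(mean_grad if agreement >= tau else 0.0)
    return masked


# Two environments: they agree on the first component and disagree on the second.
update = and_mask([[1.0, -2.0], [3.0, 2.0]], tau=1.0)
```

With `tau=1.0` only the first component survives, because the second has opposite signs in the two environments.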

Send us a PR to add your algorithm! Our implementations use ResNet50 / ResNet18 networks (He et al., 2015) and the hyper-parameter grids described here.

Available datasets

The currently available datasets are:

Send us a PR to add your dataset! Any custom image dataset with folder structure dataset/domain/class/image.xyz is readily usable. While we include some datasets from the WILDS project, please use their official code if you wish to participate in their leaderboard.
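Before pointing the code at a custom dataset, it can help to check that the folder layout matches dataset/domain/class/image.xyz. The following is a small sketch using only the standard library; `index_dataset` and the domain and class names in the demo are hypothetical and are not part of DomainBed.

```python
import tempfile
from pathlib import Path


def index_dataset(root):
    """Map each domain folder to its class folders and their image counts."""
    index = {}
    for domain_dir in sorted(p for p in Path(root).iterdir() if p.is_dir()):
        index[domain_dir.name] = {
            class_dir.name: sum(1 for f in class_dir.iterdir() if f.is_file())
            for class_dir in sorted(p for p in domain_dir.iterdir() if p.is_dir())
        }
    return index


# Build a tiny throwaway tree with made-up domain and class names.
with tempfile.TemporaryDirectory() as root:
    for domain, cls, name in [("photo", "dog", "a.jpg"),
                              ("photo", "cat", "b.jpg"),
                              ("sketch", "dog", "c.jpg")]:
        folder = Path(root) / domain / cls
        folder.mkdir(parents=True, exist_ok=True)
        (folder / name).touch()
    layout = index_dataset(root)
```

A domain whose class set differs from the others, as `sketch` does here, shows up immediately in the returned dictionary.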

Available model selection criteria

Model selection criteria differ in what data is used to choose the best hyper-parameters for a given model:

  • IIDAccuracySelectionMethod: A random subset from the data of the training domains.
  • LeaveOneOutSelectionMethod: A random subset from the data of a held-out (not training, not testing) domain.
  • OracleSelectionMethod: A random subset from the data of the test domain.
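The difference between the three criteria can be illustrated with a simplified sketch: each criterion ranks the same candidate runs by accuracy on a different validation split, so each may pick a different run. The record fields (`train_val_acc`, `held_out_acc`, `test_val_acc`, `test_acc`) and all numbers are made up for illustration. The real selection classes are more involved, so treat this as a mental model rather than the implementation.

```python
# Made-up records, one per hyper-parameter choice. Field names are hypothetical.
runs = [
    {"hparams": {"lr": 1e-3}, "train_val_acc": 0.95, "held_out_acc": 0.60,
     "test_val_acc": 0.55, "test_acc": 0.54},
    {"hparams": {"lr": 1e-4}, "train_val_acc": 0.90, "held_out_acc": 0.72,
     "test_val_acc": 0.63, "test_acc": 0.62},
    {"hparams": {"lr": 1e-5}, "train_val_acc": 0.85, "held_out_acc": 0.68,
     "test_val_acc": 0.70, "test_acc": 0.69},
]

# Which validation split each criterion is allowed to look at.
CRITERIA = {
    "IIDAccuracySelectionMethod": "train_val_acc",
    "LeaveOneOutSelectionMethod": "held_out_acc",
    "OracleSelectionMethod": "test_val_acc",
}


def select(candidates, criterion):
    """Return the candidate with the best accuracy on the criterion's split."""
    key = CRITERIA[criterion]
    return max(candidates, key=lambda run: run[key])


chosen = {name: select(runs, name)["hparams"]["lr"] for name in CRITERIA}
```

In this toy data each criterion selects a different learning rate, which is why reported numbers should always state the selection method used.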

Quick start

Download the datasets:

python -m domainbed.scripts.download \
       --data_dir=/my/datasets/path

Train a model:

python -m domainbed.scripts.train \
       --algorithm SANDMask \
       --dataset Spirals \
       --test_env 0

Launch a sweep:

python -m domainbed.scripts.sweep launch \
       --data_dir=/my/datasets/path \
       --output_dir=/my/sweep/output/path \
       --command_launcher MyLauncher

Here, MyLauncher is your cluster's command launcher, as implemented in command_launchers.py. At the time of writing, the entire sweep trains tens of thousands of models (all algorithms x all datasets x 3 independent trials x 20 random hyper-parameter choices). You can pass arguments to make the sweep smaller:

python -m domainbed.scripts.sweep launch \
       --data_dir=/my/datasets/path \
       --output_dir=/my/sweep/output/path \
       --command_launcher MyLauncher \
       --algorithms SANDMask \
       --datasets Spirals \
       --n_hparams 20 \
       --n_trials 3
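To estimate how much a restricted sweep saves, the product described above (algorithms x datasets x trials x hyper-parameter choices) can be computed directly. The helper `sweep_size` is hypothetical and deliberately simple: the real sweep may multiply the count further, for example by launching separate runs per test environment, so read the result as a lower bound.

```python
def sweep_size(n_algorithms, n_datasets, n_trials=3, n_hparams=20):
    """Number of training runs implied by the four-way product."""
    return n_algorithms * n_datasets * n_trials * n_hparams


# One algorithm (SANDMask) on one dataset (Spirals), 3 trials, 20 hparam draws.
reduced = sweep_size(n_algorithms=1, n_datasets=1, n_trials=3, n_hparams=20)
```

For the restricted command this gives 60 runs before any per-environment multiplier.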

After all jobs have either succeeded or failed, you can delete the data from failed jobs with python -m domainbed.scripts.sweep delete_incomplete and then re-launch them by running python -m domainbed.scripts.sweep launch again. Specify the same command-line arguments in all calls to sweep as you did the first time; this is how the sweep script knows which jobs were launched originally.
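One plausible way for a sweep script to recognise previously launched jobs from the command-line arguments alone is to derive each job's identity from a hash of its arguments. The sketch below assumes that approach; `job_id`, the choice of SHA-256, and the argument names are illustrative assumptions, and the actual logic lives in the `domainbed.scripts.sweep` module.

```python
import hashlib
import json


def job_id(train_args):
    """Derive a stable identifier from a job's arguments."""
    # sort_keys makes the serialization independent of dict insertion order.
    args_str = json.dumps(train_args, sort_keys=True)
    return hashlib.sha256(args_str.encode("utf-8")).hexdigest()


first_launch = job_id({"algorithm": "SANDMask", "dataset": "Spirals", "trial_seed": 0})
relaunch = job_id({"dataset": "Spirals", "trial_seed": 0, "algorithm": "SANDMask"})
other_job = job_id({"algorithm": "SANDMask", "dataset": "Spirals", "trial_seed": 1})
```

Under this scheme, identical arguments map to the same identifier on every call, while any changed argument produces a different one. That is why changing the arguments between calls would make the script look for a different set of jobs.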

To view the results of your sweep:

python -m domainbed.scripts.collect_results \
       --input_dir=/my/sweep/output/path

Running unit tests

DomainBed includes some unit tests and end-to-end tests. While not exhaustive, they are a good sanity check. To run the tests:

python -m unittest discover

By default, this only runs tests that don't depend on a dataset directory. To also run the tests that need one, set DATA_DIR:

DATA_DIR=/my/datasets/path python -m unittest discover
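A common way to gate tests on an environment variable like this is `unittest.skipIf`. The following sketch shows the pattern; the class `DatasetTests` and its test methods are hypothetical examples, not tests from this repository.

```python
import os
import unittest


class DatasetTests(unittest.TestCase):
    def test_no_data_needed(self):
        # Always runs: does not touch any dataset directory.
        self.assertEqual(1 + 1, 2)

    @unittest.skipIf("DATA_DIR" not in os.environ,
                     "needs DATA_DIR environment variable")
    def test_data_dir_exists(self):
        # Only runs when DATA_DIR is set in the environment.
        self.assertTrue(os.path.isdir(os.environ["DATA_DIR"]))
```

Because the condition is evaluated when the class is defined, DATA_DIR has to be set before the test process starts, as in the command above.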

License

This source code is released under the MIT license, included here.
