Showing 933 changed files with 36,150 additions and 81,464 deletions.
@@ -4,4 +4,4 @@
tests/__pycache__/
*.pyc
.vscode/
.snakemake/
data/pacs
@@ -0,0 +1,19 @@
repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.5.0  # Use the specific version of the repo
    hooks:
      - id: trailing-whitespace
      - id: end-of-file-fixer
      - id: check-yaml
  - repo: https://github.com/pycqa/flake8
    rev: 7.0.0
    hooks:
      - id: flake8
  - repo: https://github.com/PyCQA/isort
    rev: 5.13.2
    hooks:
      - id: isort
  - repo: https://github.com/psf/black
    rev: 23.12.1
    hooks:
      - id: black
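For contributors who have not used pre-commit before, a typical workflow (assuming the tool is installed from PyPI) is to register the hooks once and then let them run on every commit; they can also be run manually over the whole tree:

```bash
pip install pre-commit          # one-time installation of the pre-commit framework
pre-commit install              # register the hooks defined in .pre-commit-config.yaml
pre-commit run --all-files      # run all hooks manually over the entire repository
```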
@@ -0,0 +1,2 @@
# include data
recursive-include zdata *
@@ -1,75 +1,103 @@
# DomainLab: Playground for Domain Generalization
# DomainLab: modular python package for training domain invariant neural networks

![GH Actions CI ](https://github.com/marrlab/DomainLab/actions/workflows/ci.yml/badge.svg)
![GH Actions CI ](https://github.com/marrlab/DomainLab/actions/workflows/ci.yml/badge.svg?branch=master)
[![codecov](https://codecov.io/gh/marrlab/DomainLab/branch/master/graph/badge.svg)](https://app.codecov.io/gh/marrlab/DomainLab)
[![Codacy Badge](https://app.codacy.com/project/badge/Grade/bc22a1f9afb742efb02b87284e04dc86)](https://www.codacy.com/gh/marrlab/DomainLab/dashboard)
[![Documentation](https://img.shields.io/badge/Documentation-Here)](https://marrlab.github.io/DomainLab/)
[![pages-build-deployment](https://github.com/marrlab/DomainLab/actions/workflows/pages/pages-build-deployment/badge.svg)](https://github.com/marrlab/DomainLab/actions/workflows/pages/pages-build-deployment)

## Domain Generalization and DomainLab
## Distribution shifts, domain generalization and DomainLab

Domain Generalization aims at learning domain invariant features by utilizing data from multiple domains so the learned feature can generalize to new unseen domains.
Neural networks trained using data from a specific distribution (domain) usually fail to generalize to novel distributions (domains). Domain generalization aims at learning domain invariant features by utilizing data from multiple domains (data sites, cohorts, batches, vendors) so that the learned features generalize to new, unseen domains (distributions).

## Why a dedicated package
DomainLab is a software platform with state-of-the-art domain generalization algorithms implemented, designed with maximal decoupling of its software components to enhance code reuse.
Domain generalization algorithms try to learn domain invariant features by adding regularization upon the ERM (Empirical Risk Minimization) loss. A typical setting for evaluating a domain generalization algorithm is the so-called leave-one-domain-out scheme, where one dataset is collected from each distribution. Each time, one dataset/domain is left out as the test set to estimate the generalization performance of a model trained upon the rest of the domains/datasets.
### DomainLab
DomainLab decouples the following concepts or objects:
- task $M$: In DomainLab, a task is a container for datasets from different domains (e.g. from distribution $D_1$ and $D_2$). Tasks offer a static protocol to evaluate the generalization performance of a neural network: which dataset(s) are used for training and which dataset(s) for testing.
- neural network: a map $\phi$ from the input data to the feature space and a map $\varphi$ from feature space to output $\hat{y}$ (e.g. decision variable).
- model: structural risk in the form of $\ell() + \mu R()$ where
  - $\ell(Y, \hat{y}=\varphi(\phi(X)))$ is the task specific empirical loss (e.g. cross entropy for a classification task).
  - $R(\phi(X))$ is the penalty loss that boosts domain invariant feature extraction using $\phi$.
  - $\mu$ is the multiplier for the corresponding penalty term.
- trainer: an object that guides the data flow to the model and appends further domain invariant losses, such as inter-domain feature alignment.

We offer detailed documentation on how these models and trainers work on our documentation page: https://marrlab.github.io/DomainLab/

Once you come across a claim that a domain generalization algorithm A can generate a "better" model h upon some datasets D with "better" performance compared to other algorithms, have you ever wondered:
DomainLab makes it possible to combine models with models, trainers with models, and trainers with trainers in a decorator-pattern-like line of code `Trainer A(Trainer B(Model C(Model D(network E), network E, network F)))`, which corresponds to $\ell() + \mu_a R_a() + \mu_b R_b() + \mu_c R_c() + \mu_d R_d()$, where Model C and Model D share neural network E, but Model C has an extra neural network F. All models share the same neural network for feature extraction, but can have different auxiliary networks for $R()$.
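As a purely illustrative sketch of this decorator-style composition (the class and function names below are hypothetical and not DomainLab's actual API), nested wrappers can accumulate the weighted penalty terms on top of the ERM loss:

```python
# Hypothetical sketch, not DomainLab's real API: each decorator adds mu * R()
# on top of the loss of the object it wraps, mirroring
#   ell() + mu_a * R_a() + mu_b * R_b() + ...
class ErmModel:
    """Base model: only the task-specific loss ell()."""
    def __init__(self, task_loss):
        self.task_loss = task_loss  # callable: batch -> float

    def loss(self, batch):
        return self.task_loss(batch)


class PenaltyDecorator:
    """Wraps a model (or another decorator) and adds mu * R() to its loss."""
    def __init__(self, wrapped, penalty, mu):
        self.wrapped, self.penalty, self.mu = wrapped, penalty, mu

    def loss(self, batch):
        return self.wrapped.loss(batch) + self.mu * self.penalty(batch)


# Nesting mirrors Trainer A(Trainer B(Model C(Model D(network E), ...)))
base = ErmModel(task_loss=lambda batch: 0.3)                            # ell()
model_d = PenaltyDecorator(base, penalty=lambda batch: 0.1, mu=1.0)     # + mu_d R_d()
model_c = PenaltyDecorator(model_d, penalty=lambda batch: 0.2, mu=2.0)  # + mu_c R_c()
print(model_c.loss(batch=None))  # 0.3 + 1.0*0.1 + 2.0*0.2 = 0.8
```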
- Is this mostly attributed to a more "powerful" neural network architecture of model A compared to others? What will happen if I change the backbone neural network of algorithm A from ResNet to AlexNet?
- Is this mostly attributed to the protocol for estimating the generalization performance, e.g. the dataset split? Will this algorithm "work" for my datasets?
- Is this mostly attributed to the "clever" regularization algorithm or a special loss function A has used for the neural network?
<div style="align: center; text-align:center;">
<figure>
<img src="https://github.com/marrlab/DomainLab/blob/master/docs/figs/invarfeat4dg.png?raw=true" style="width:300px;"/>
</figure>
</div>

To maximally decouple these contributing factors, DomainLab was implemented with software design patterns, where

- Domain generalization algorithms were implemented in a way that keeps the underlying neural network architecture transparent, i.e. the concrete neural network architecture can be replaced like a plugin by specifying a custom neural network architecture implemented in a python file. See [Specify Custom Neural Networks for an algorithm](./docs/doc_custom_nn.md)
## Getting started

- To evaluate a domain generalization algorithm's performance, the user can specify a "Task" in the form of a custom python file and feed it into the command line argument; it is thus at the user's discretion how to evaluate an algorithm, so that all domain generalization algorithms can be compared fairly. See [Task Specification](./docs/doc_tasks.md).
### Installation
For the development version on GitHub, see [Installation and Dependencies handling](./docs/doc_install.md)

- To simply test an algorithm's performance, there is no need to change any code inside this repository; the user only needs to extend this repository to fit their custom needs.
We also offer a PyPI version here https://pypi.org/project/domainlab/ which one could install via `pip install domainlab`; it is recommended to create a virtual environment for it.

## Getting started
### Installation
### Task specification
We offer various ways for the user to specify a scenario to evaluate the generalization performance via training on a limited number of datasets. See details in
[Task Specification](./docs/doc_tasks.md)

- Install via python-poetry:
Read the python-poetry documentation https://python-poetry.org/ and use the configuration file in this repository.
### Example and usage

- **Or** only install dependencies via pip
Suppose you have cloned the repository and have changed directory to the cloned repository.
#### Either clone this repo and use command line
```bash
pip install -r requirements.txt
```
`python main_out.py -c ./examples/conf/vlcs_diva_mldg_dial.yaml`
where the configuration file below can be downloaded [here](https://raw.githubusercontent.com/marrlab/DomainLab/master/examples/conf/vlcs_diva_mldg_dial.yaml)
```
te_d: caltech                       # domain name of test domain
tpath: examples/tasks/task_vlcs.py  # python file path to specify the task
bs: 2                               # batch size
model: dann_diva                    # combine model DANN with DIVA
epos: 1                             # number of epochs
trainer: mldg_dial                  # combine trainer MLDG and DIAL
gamma_y: 700000.0                   # hyperparameter of diva
gamma_d: 100000.0                   # hyperparameter of diva
npath: examples/nets/resnet.py      # neural network for class classification
npath_dom: examples/nets/resnet.py  # neural network for domain classification
```
See details in [Command line usage](./docs/doc_usage_cmd.md)
### Basic usage
Suppose you have cloned the repository and have the dependencies ready; change directory to the repository:
DomainLab comes with some minimal toy-datasets to test its basic functionality. To train a domain generalization model with a user-specified task, one can execute a command similar to the following.
#### Or program against the DomainLab API

```bash
python main_out.py --te_d=caltech --tpath=examples/tasks/task_vlcs.py --debug --bs=2 --aname=diva --gamma_y=7e5 --gamma_d=1e5 --nname=alexnet --nname_dom=conv_bn_pool_2
```
See example here: [Transformer as feature extractor, decorate JIGEN with DANN, training using MLDG decorated by DIAL](https://github.com/marrlab/DomainLab/blob/master/examples/api/jigen_dann_transformer.py)

where `--tpath` specifies the path of a user-specified python file which defines the domain generalization task, see the example in [Task Specification](./docs/doc_tasks.md). `--aname` specifies which algorithm to use, see [Available Algorithms](./docs/doc_algos.md), `--bs` specifies the batch size, and `--debug` restricts the run to 2 epochs and saves results with the prefix 'debug'. For DIVA, the hyper-parameters include `--gamma_y=7e5`, the relative weight of the ERM loss compared to the ELBO loss, and `--gamma_d=1e5`, the relative weight of the domain classification loss compared to the ELBO loss.
`--nname` specifies which neural network to use for feature extraction for classification; `--nname_dom` specifies which neural network to use for feature extraction of domains.
For usage of other arguments, check with

```bash
python main_out.py --help
```
### Benchmark different methods
DomainLab provides a powerful benchmark functionality.
To benchmark several algorithms (combinations of neural networks, models, trainers and associated hyperparameters), a single command along with a benchmark configuration file is sufficient. See details in [benchmarks documentation and tutorial](./docs/doc_benchmark.md)

See also [Examples](./docs/doc_examples.md).
One could simply run
`bash run_benchmark_slurm.sh your_benchmark_configuration.yaml` to launch different experiments with the specified configuration.
### Output structure (results storage) and Performance Measure
[Output structure and Performance Measure](./docs/doc_output.md)

## Custom Usage
For example, the following result (without any augmentation like flipping) is for the PACS dataset using ResNet.

### Define your task
Do you have your own data that comes from different domains? Create a task for your data and benchmark different domain generalization algorithms according to the following example. See
[Task Specification](./docs/doc_tasks.md)
<div style="align: center; text-align:center;">
<figure>
<img src="https://github.com/marrlab/DomainLab/blob/master/docs/figs/stochastic_variation_two_rows.png?raw=true" style="width:800px;"/>
<div class="caption" style="align: center; text-align:center;">
<figcaption>Benchmark results plot generated from DomainLab, where each rectangle represents one model-trainer combination, each bar inside the rectangle represents a unique hyperparameter index associated with that combination, and each dot represents a random seed.</figcaption>
</div>
</figure>
</div>

### Custom Neural network
This library decouples the concept of algorithm (model) and neural network architecture, so the user can plug different neural network architectures into the same algorithm. See
[Specify Custom Neural Networks for an algorithm](./docs/doc_custom_nn.md)

## Software Design Pattern, Extend or Contribution, Credits
[Extend or Contribute](./docs/doc_extend_contribute.md)
### Temporary citation
```bibtex
@manual{domainlab,
  title={{DomainLab: modular python package for training domain invariant neural networks}},
  author={{Xudong Sun, et al.}},
  organization={{Institute of AI for Health}},
  year={2023},
  url={https://github.com/marrlab/DomainLab},
  note={temporary citation for domainlab}
}
```
@@ -1,12 +1,23 @@
#!/bin/bash -x -v
set -e  # exit upon first error
starttime=`date +%s`

# run examples
bash -x -v ci_run_examples.sh

# run test
sh ci_pytest_cov.sh

# run benchmark
./run_benchmark_standalone.sh examples/benchmark/demo_benchmark.yaml

# update documentation
# if git status | grep -q 'master'; then
# echo "in master branch"
sh gen_doc.sh
# fi

endtime=`date +%s`
runtime=$((endtime-starttime))
echo "total time used:"
echo "$runtime"
@@ -0,0 +1,3 @@
sed -n '/```shell/,/```/ p' docs/doc_benchmark.md | sed '/^```/ d' > ./sh_temp_benchmark.sh
bash -x -v -e sh_temp_benchmark.sh
echo "benchmark done"
@@ -1,3 +1,6 @@
#!/bin/bash
# export CUDA_VISIBLE_DEVICES=""
python -m pytest --cov=domainlab
export CUDA_VISIBLE_DEVICES=""
# although the garbage collector has been explicitly called, sometimes there is still a CUDA out of memory error,
# so it is better not to use the GPU for pytest, to ensure no CUDA out of memory error occurs
# --cov-report term-missing shows file-wise coverage and missing lines in the console
python -m pytest --cov=domainlab --cov-report html
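By default, pytest-cov writes the HTML report produced by `--cov-report html` to the `htmlcov/` directory; it can then be inspected locally, for example:

```bash
xdg-open htmlcov/index.html   # Linux; on macOS use `open`, or simply open the file in a browser
```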
@@ -1,4 +1,37 @@
#!/bin/bash -x -v
set -e  # exit upon first error
sed 's/`//g' docs/doc_examples.md > ./sh_temp.sh
bash -x -v -e sh_temp.sh
# >> append content
# > erase original content

# echo "#!/bin/bash -x -v" > sh_temp_example.sh
sed -n '/```shell/,/```/ p' docs/doc_examples.md | sed '/^```/ d' >> ./sh_temp_example.sh
split -l 5 sh_temp_example.sh sh_example_split
for file in sh_example_split*;
do (echo "#!/bin/bash -x -v" > "$file"_exe && cat "$file" >> "$file"_exe && bash -x -v "$file"_exe && rm -r zoutput);
done
# bash -x -v -e sh_temp_example.sh
echo "general examples done"

echo "#!/bin/bash -x -v" > sh_temp_mnist.sh
sed -n '/```shell/,/```/ p' docs/doc_MNIST_classification.md | sed '/^```/ d' >> ./sh_temp_mnist.sh
bash -x -v -e sh_temp_mnist.sh
echo "mnist example done"

echo "#!/bin/bash -x -v" > sh_temp_nn.sh
sed -n '/```shell/,/```/ p' docs/doc_custom_nn.md | sed '/^```/ d' >> ./sh_temp_nn.sh
bash -x -v -e sh_temp_nn.sh
echo "arbitrary nn done"

echo "#!/bin/bash -x -v" > sh_temp_task.sh
sed -n '/```shell/,/```/ p' docs/doc_tasks.md | sed '/^```/ d' >> ./sh_temp_task.sh
bash -x -v -e sh_temp_task.sh
echo "task done"

echo "#!/bin/bash -x -v" > sh_temp_readme.sh
sed -n '/```shell/,/```/ p' README.md | sed '/^```/ d' >> ./sh_temp_readme.sh
bash -x -v -e sh_temp_readme.sh
echo "read me done"

echo "#!/bin/bash -x -v" > sh_temp_extend.sh
sed -n '/```shell/,/```/ p' docs/doc_extend_contribute.md | sed '/^```/ d' >> ./sh_temp_extend.sh
bash -x -v -e sh_temp_extend.sh
@@ -1 +1 @@
Hello World
@@ -1,3 +1,10 @@
1. The h5py files were pre-read using cv2, so they are in BGR channel order.
2. This benchmark is sensitive to the train/val split, so please use this train/val split for fair comparisons.

[https://drive.google.com/uc?id=1JFr8f805nMUelQWWmfnJR3y4_SYoN5Pd](https://drive.google.com/uc?id=1JFr8f805nMUelQWWmfnJR3y4_SYoN5Pd)

https://github.com/facebookresearch/DomainBed/blob/4294ec699df761b46a1505734f6be16ef009cad9/domainbed/scripts/download.py#L29
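Since the arrays were stored in cv2's BGR channel order (point 1 above), a minimal sketch for converting one image back to RGB before further processing, assuming numpy and opencv-python are installed:

```python
import cv2
import numpy as np

img_bgr = np.zeros((224, 224, 3), dtype=np.uint8)   # placeholder for an image loaded from the h5 file
img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)  # swap the B and R channels
# equivalent without cv2: img_rgb = img_bgr[..., ::-1]
```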
@@ -0,0 +1,64 @@
"this script can be used to download the pacs dataset"
import os
import tarfile
from zipfile import ZipFile

import gdown


def stage_path(data_dir, name):
    """
    creates the path to data_dir/name
    if it does not exist already
    """
    full_path = os.path.join(data_dir, name)

    if not os.path.exists(full_path):
        os.makedirs(full_path)

    return full_path


def download_and_extract(url, dst, remove=True):
    """
    downloads and extracts the data behind the url
    and saves it at dst
    """
    gdown.download(url, dst, quiet=False)

    if dst.endswith(".tar.gz"):
        # use tarfile.open (the built-in open does not accept "r:gz")
        with tarfile.open(dst, "r:gz") as tar:
            tar.extractall(os.path.dirname(dst))

    if dst.endswith(".tar"):
        with tarfile.open(dst, "r:") as tar:
            tar.extractall(os.path.dirname(dst))

    if dst.endswith(".zip"):
        with ZipFile(dst, "r") as zfile:
            zfile.extractall(os.path.dirname(dst))

    if remove:
        os.remove(dst)


def download_pacs(data_dir):
    """
    download and extract dataset pacs.
    Dataset is saved at location data_dir
    """
    full_path = stage_path(data_dir, "PACS")

    download_and_extract(
        "https://drive.google.com/uc?id=1JFr8f805nMUelQWWmfnJR3y4_SYoN5Pd",
        os.path.join(data_dir, "PACS.zip"),
    )

    os.rename(os.path.join(data_dir, "kfold"), full_path)


if __name__ == "__main__":
    download_pacs("../pacs")
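A possible way to run this script (the file name below is a placeholder for wherever the script lives in the repository; `gdown` is the only non-standard dependency it imports):

```bash
pip install gdown                 # Google Drive downloader used by the script
python download_pacs_script.py    # placeholder name; downloads and extracts PACS into ../pacs/PACS
```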
@@ -1,7 +1,2 @@
param_index, task, algo, epos, te_d, seed, params, acc, precision, recall, specificity, f1, auroc
param_index, method, algo, epos, te_d, seed, params, acc, precision, recall, specificity, f1, auroc
0, Task1, diva, 2, caltech, 1, "{'gamma_y': 1688039, 'gamma_d': 265711}", 0.75, 0.87068963, 0.5588235, 0.5588235, 0.53100574, 0.8098495
0, Task1, diva, 2, caltech, 2, "{'gamma_y': 1688039, 'gamma_d': 265711}", 0.75, 0.87068963, 0.5588235, 0.5588235, 0.53100574, 0.8098495
0, Task2, hduva, 2, caltech, 1, "{'a': 1, 'b': 3}", 0.76, 0.76, 0.49, 0.56, 0.57, 0.76
0, Task2, hduva, 2, caltech, 2, "{'a': 1, 'b': 3}", 0.76, 0.76, 0.49, 0.56, 0.57, 0.76
0, Task3, deepall, 2, caltech, 1, "{}", 0.7, 0.65, 0.5, 0.5, 0.5, 0.7
0, Task3, deepall, 2, caltech, 2, "{}", 0.7, 0.65, 0.5, 0.5, 0.5, 0.7
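As a minimal sketch for working with such a results file (assuming pandas is installed; the file path below is a placeholder, and the new header names are used), one could aggregate accuracy across random seeds per method and hyperparameter index:

```python
import pandas as pd

# skipinitialspace handles the blank after each comma in the header and rows above
df = pd.read_csv("results.csv", skipinitialspace=True)  # placeholder path

# mean/std of accuracy across seeds for each (method, param_index) pair
summary = df.groupby(["method", "param_index"])["acc"].agg(["mean", "std"])
print(summary)
```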