diff --git a/docs/docDIAL.md b/docs/docDIAL.md
index 8b8111de8..600f8b2cb 100644
--- a/docs/docDIAL.md
+++ b/docs/docDIAL.md
@@ -72,3 +72,9 @@ This procedure yields the following hyperparameters:
 - `--dial_epsilon`: pixel wise threshold to perturb images
 - `--gamma_reg`: ? ($\epsilon$ in the paper)
 - `--lr`: learning rate ($\alpha$ in the paper)
+
+# Examples
+
+```shell
+python main_out.py --te_d=0 --task=mnistcolor10 --model=erm --trainer=dial --nname=conv_bn_pool_2
+```
diff --git a/docs/docFishr.md b/docs/docFishr.md
index 08580d9fe..e2ba4c1b9 100644
--- a/docs/docFishr.md
+++ b/docs/docFishr.md
@@ -72,6 +72,10 @@
 
 For more details, see the reference below or the domainlab code.
 
 
+# Examples
+```shell
+python main_out.py --te_d=0 --task=mini_vlcs --model=erm --trainer=fishr --nname=alexnet --bs=2 --nocu
+```
 
@@ -79,3 +83,4 @@ _Reference:_
 Rame, Alexandre, Corentin Dancette, and Matthieu Cord. "Fishr: Invariant gradient
 variances for out-of-distribution generalization." International Conference on
 Machine Learning. PMLR, 2022.
+
diff --git a/docs/docHDUVA.md b/docs/docHDUVA.md
index 4abcb71e9..47af4d344 100644
--- a/docs/docHDUVA.md
+++ b/docs/docHDUVA.md
@@ -52,6 +52,18 @@ Alternatively, one could use an existing neural network in DomainLab using `nnam
 ## Hyperparameter for warmup
 Finally, the number of epochs for hyper-parameter warm-up can be specified via the argument `warmup`.
 
+## Examples
+### use hduva on color mnist, train on three domains
+```shell
+python main_out.py --tr_d 0 1 2 --te_d 3 --bs=2 --task=mnistcolor10 --model=hduva --nname=conv_bn_pool_2 --gamma_y=7e5 --nname_encoder_x2topic_h=conv_bn_pool_2 --nname_encoder_sandwich_x2h4zd=conv_bn_pool_2
+```
+
+### hduva is domain-unsupervised, so it also works with a single training domain
+```shell
+python main_out.py --tr_d 0 --te_d 3 4 --bs=2 --task=mnistcolor10 --model=hduva --nname=conv_bn_pool_2 --gamma_y=7e5 --nname_encoder_x2topic_h=conv_bn_pool_2 --nname_encoder_sandwich_x2h4zd=conv_bn_pool_2
+```
+
+
 
 Please cite our paper if you find it useful!
 ```text
diff --git a/docs/docIRM.md b/docs/docIRM.md
index 955ff14a3..0b8114de8 100644
--- a/docs/docIRM.md
+++ b/docs/docIRM.md
@@ -26,4 +26,8 @@ where $\lambda$ is a hyperparameter that controls the trade-off between the empi
 In practice, one could simply divide one mini-batch into two subsets, let $i$ and $j$ index these two subsets; multiplying the gradients computed on subset $i$ and subset $j$ gives an unbiased estimate of the squared L2 norm of the gradient.
 In detail: the squared gradient norm is estimated via the inner product between $\nabla_{w|w=1} \ell(w \circ \Phi(X^{(d, i)}), Y^{(d, i)})$ and $\nabla_{w|w=1} \ell(w \circ \Phi(X^{(d, j)}), Y^{(d, j)})$, each of dimension dim(Grad).
 For more details, see Section 3.2 and Appendix D of:
 Arjovsky et al., “Invariant Risk Minimization.”
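+
+To make this estimator concrete, below is a minimal PyTorch sketch of the two-subset penalty.
+It is an illustration only, not the DomainLab implementation; the names `irm_penalty`, `logits` and `y`
+are hypothetical, and the two subsets are simply the even- and odd-indexed samples of one mini-batch:
+
+```python
+import torch
+import torch.nn.functional as F
+
+
+def irm_penalty(logits, y):
+    """two-subset estimate of the squared gradient norm w.r.t. the dummy classifier w = 1"""
+    scale = torch.tensor(1.0, requires_grad=True)  # the scalar classifier w, fixed at 1
+    loss_i = F.cross_entropy(logits[::2] * scale, y[::2])    # loss on subset i
+    loss_j = F.cross_entropy(logits[1::2] * scale, y[1::2])  # loss on subset j
+    grad_i = torch.autograd.grad(loss_i, [scale], create_graph=True)[0]
+    grad_j = torch.autograd.grad(loss_j, [scale], create_graph=True)[0]
+    # product of two independent gradient estimates: unbiased estimate of the squared norm
+    return torch.sum(grad_i * grad_j)
+```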
+# Examples
+```shell
+python main_out.py --te_d=0 --task=mnistcolor10 --model=erm --trainer=irm --nname=conv_bn_pool_2
+```
diff --git a/docs/doc_examples.md b/docs/doc_examples.md
index 21d0b2eb2..0dddc5f5b 100755
--- a/docs/doc_examples.md
+++ b/docs/doc_examples.md
@@ -26,16 +26,6 @@ python main_out.py --te_d 0 1 --tr_d 3 5 --task=mnistcolor10 --debug --bs=2 --mo
 python main_out.py --te_d=0 --task=mnistcolor10 --keep_model --model=diva --nname=conv_bn_pool_2 --nname_dom=conv_bn_pool_2 --gamma_y=10e5 --gamma_d=1e5 --gen
 ```
 
-### use hduva on color mnist, train on 2 domains
-```shell
-python main_out.py --tr_d 0 1 2 --te_d 3 --bs=2 --task=mnistcolor10 --model=hduva --nname=conv_bn_pool_2 --gamma_y=7e5 --nname_encoder_x2topic_h=conv_bn_pool_2 --nname_encoder_sandwich_x2h4zd=conv_bn_pool_2
-```
-
-### hduva is domain-unsupervised, so it works also with a single domain
-```shell
-python main_out.py --tr_d 0 --te_d 3 4 --bs=2 --task=mnistcolor10 --model=hduva --nname=conv_bn_pool_2 --gamma_y=7e5 --nname_encoder_x2topic_h=conv_bn_pool_2 --nname_encoder_sandwich_x2h4zd=conv_bn_pool_2
-```
-
 
 ## Larger images:
diff --git a/domainlab/algos/trainers/train_causIRL.py b/domainlab/algos/trainers/train_causIRL.py
new file mode 100644
index 000000000..085abcbfa
--- /dev/null
+++ b/domainlab/algos/trainers/train_causIRL.py
@@ -0,0 +1,77 @@
+"""
+Causal-IRL style trainer: the task loss is regularised with an MMD penalty
+that matches the feature distributions of two random splits of each mini-batch.
+Authors: Alex, Xudong
+"""
+import numpy as np
+import torch
+from domainlab.algos.trainers.train_basic import TrainerBasic
+
+
+class TrainerCausalIRL(TrainerBasic):
+    """
+    causal matching: ERM loss plus an MMD penalty between two random splits
+    of the mini-batch features
+    """
+    def my_cdist(self, x1, x2):
+        """
+        squared Euclidean distance matrix, used inside the Gaussian kernel
+        """
+        # squared norms along the last (feature) dimension
+        x1_norm = x1.pow(2).sum(dim=-1, keepdim=True)
+        x2_norm = x2.pow(2).sum(dim=-1, keepdim=True)
+        # x2_norm is [batchsize, 1]
+        # addmm: first argument + alpha * (2nd argument @ 3rd argument)
+        # X1[batchsize, dimfeat] @ X2^T[dimfeat, batchsize]
+        # alpha: scaling factor for the matrix product (default: 1)
+        # x2_norm.transpose(-2, -1) is a row vector, x1_norm a column vector,
+        # so res[a, b] = ||x1[a]||^2 - 2 <x1[a], x2[b]> + ||x2[b]||^2
+        res = torch.addmm(x2_norm.transpose(-2, -1),
+                          x1,
+                          x2.transpose(-2, -1), alpha=-2).add_(x1_norm)
+        return res.clamp_min_(1e-30)
+
+    def gaussian_kernel(self, x, y):
+        """
+        multi-bandwidth Gaussian kernel for MMD
+        """
+        gamma = [0.001, 0.01, 0.1, 1, 10, 100, 1000]
+        dist = self.my_cdist(x, y)
+        tensor = torch.zeros_like(dist)
+        for g in gamma:
+            tensor.add_(torch.exp(dist.mul(-g)))
+        return tensor
+
+    def mmd(self, x, y):
+        """
+        maximum mean discrepancy
+        """
+        kxx = self.gaussian_kernel(x, x).mean()
+        kyy = self.gaussian_kernel(y, y).mean()
+        kxy = self.gaussian_kernel(x, y).mean()
+        return kxx + kyy - 2 * kxy
+
+    def tr_batch(self, tensor_x, tensor_y, tensor_d, others, ind_batch, epoch):
+        """
+        optimize neural network one step upon a mini-batch of data
+        """
+        self.before_batch(epoch, ind_batch)
+        tensor_x, tensor_y, tensor_d = (
+            tensor_x.to(self.device),
+            tensor_y.to(self.device),
+            tensor_d.to(self.device),
+        )
+        self.optimizer.zero_grad()
+
+        features = self.get_model().extract_semantic_feat(tensor_x)
+
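+        # split the feature mini-batch at a random position and penalise the
+        # maximum mean discrepancy (MMD) between the two parts so that their
+        # feature distributions match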
+        pos_batch_break = np.random.randint(0, tensor_x.shape[0])
+        first = features[:pos_batch_break]
+        second = features[pos_batch_break:]
+        if len(first) > 1 and len(second) > 1:
+            penalty = torch.nan_to_num(self.mmd(first, second))
+        else:
+            penalty = torch.tensor(0.0)
+        loss = self.cal_loss(tensor_x, tensor_y, tensor_d, others)
+        loss = loss + penalty
+        loss.backward()
+        self.optimizer.step()
+        self.after_batch(epoch, ind_batch)
+        self.counter_batch += 1
diff --git a/domainlab/algos/trainers/zoo_trainer.py b/domainlab/algos/trainers/zoo_trainer.py
index e4a8c7bd5..d1eccb59b 100644
--- a/domainlab/algos/trainers/zoo_trainer.py
+++ b/domainlab/algos/trainers/zoo_trainer.py
@@ -4,11 +4,13 @@
 from domainlab.algos.trainers.train_basic import TrainerBasic
 from domainlab.algos.trainers.train_ema import TrainerMA
 from domainlab.algos.trainers.train_dial import TrainerDIAL
-from domainlab.algos.trainers.train_hyper_scheduler import TrainerHyperScheduler
+from domainlab.algos.trainers.train_hyper_scheduler \
+    import TrainerHyperScheduler
 from domainlab.algos.trainers.train_matchdg import TrainerMatchDG
 from domainlab.algos.trainers.train_mldg import TrainerMLDG
 from domainlab.algos.trainers.train_fishr import TrainerFishr
 from domainlab.algos.trainers.train_irm import TrainerIRM
+from domainlab.algos.trainers.train_causIRL import TrainerCausalIRL
 
 
 class TrainerChainNodeGetter(object):
@@ -54,6 +56,7 @@ def __call__(self, lst_candidates=None, default=None, lst_excludes=None):
         chain = TrainerFishr(chain)
         chain = TrainerIRM(chain)
         chain = TrainerHyperScheduler(chain)
+        chain = TrainerCausalIRL(chain)
         node = chain.handle(self.request)
         head = node
         while self._list_str_trainer:
diff --git a/scripts/ci_run_examples.sh b/scripts/ci_run_examples.sh
index 9f6b4e041..98e0046a6 100644
--- a/scripts/ci_run_examples.sh
+++ b/scripts/ci_run_examples.sh
@@ -3,14 +3,24 @@
 set -e # exit upon first error
 # >> append content
 # > erase original content
-# echo "#!/bin/bash -x -v" > sh_temp_example.sh
-sed -n '/```shell/,/```/ p' docs/doc_examples.md | sed '/^```/ d' >> ./sh_temp_example.sh
-split -l 5 sh_temp_example.sh sh_example_split
-for file in sh_example_split*;
-do (echo "#!/bin/bash -x -v" > "$file"_exe && cat "$file" >> "$file"_exe && bash -x -v "$file"_exe && rm -r zoutput);
+
+files=("docs/docDIAL.md" "docs/docIRM.md" "docs/doc_examples.md" "docs/docHDUVA.md")
+
+for file in "${files[@]}"
+do
+echo "Processing $file"
+# no need to remove sh_temp_algo.sh since the following line overwrites it each time
+echo "#!/bin/bash -x -v" > sh_temp_algo.sh
+# extract the ```shell code blocks and strip the ``` markers;
+# we use >> here to append, so the header #!/bin/bash -x -v is kept
+sed -n '/```shell/,/```/ p' "$file" | sed '/^```/ d' >> ./sh_temp_algo.sh
+cat sh_temp_algo.sh
+bash -x -v -e sh_temp_algo.sh
+echo "finished with $file"
 done
-# bash -x -v -e sh_temp_example.sh
-echo "general examples done"
+
+
 echo "#!/bin/bash -x -v" > sh_temp_mnist.sh
 sed -n '/```shell/,/```/ p' docs/doc_MNIST_classification.md | sed '/^```/ d' >> ./sh_temp_mnist.sh
diff --git a/tests/test_causal_irl.py b/tests/test_causal_irl.py
new file mode 100644
index 000000000..5292ef18c
--- /dev/null
+++ b/tests/test_causal_irl.py
@@ -0,0 +1,13 @@
+"""
+end-to-end test
+"""
+from tests.utils_test import utils_test_algo
+
+
+def test_causal_irl():
+    """
+    causal irl
+    """
+    args = "--te_d 0 --tr_d 3 7 --bs=32 --debug --task=mnistcolor10 \
+           --model=erm --nname=conv_bn_pool_2 --trainer=causalirl"
+    utils_test_algo(args)