Commit
Merge branch 'master' into erm_hyper_init
smilesun authored Sep 11, 2024
2 parents 151703a + 7494a61 commit 47e0ce2
Showing 9 changed files with 138 additions and 18 deletions.
6 changes: 6 additions & 0 deletions docs/docDIAL.md
@@ -72,3 +72,9 @@ This procedure makes the following hyperparameters available:
- `--dial_epsilon`: pixel-wise threshold for perturbing images
- `--gamma_reg`: ? ($\epsilon$ in the paper)
- `--lr`: learning rate ($\alpha$ in the paper)

# Examples

```shell
python main_out.py --te_d=0 --task=mnistcolor10 --model=erm --trainer=dial --nname=conv_bn_pool_2
```
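
The DIAL-specific hyperparameters listed above can also be set explicitly; a sketch with illustrative values (not recommended defaults):

```shell
python main_out.py --te_d=0 --task=mnistcolor10 --model=erm --trainer=dial --nname=conv_bn_pool_2 --dial_epsilon=0.01 --gamma_reg=1.0 --lr=1e-4
```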
5 changes: 5 additions & 0 deletions docs/docFishr.md
@@ -72,10 +72,15 @@ For more details, see the reference below or the domainlab code.



# Examples
```shell
python main_out.py --te_d=0 --task=mini_vlcs --model=erm --trainer=fishr --nname=alexnet --bs=2 --nocu
```



_Reference:_
Rame, Alexandre, Corentin Dancette, and Matthieu Cord. "Fishr:
Invariant gradient variances for out-of-distribution generalization."
International Conference on Machine Learning. PMLR, 2022.

12 changes: 12 additions & 0 deletions docs/docHDUVA.md
@@ -52,6 +52,18 @@ Alternatively, one could use an existing neural network in DomainLab using `nnam

## Hyperparameter for warm-up
Finally, the number of epochs for hyper-parameter warm-up can be specified via the argument `warmup` (see the warm-up example below).
## Examples
### use hduva on color mnist, train on 2 domains
```shell
python main_out.py --tr_d 0 1 2 --te_d 3 --bs=2 --task=mnistcolor10 --model=hduva --nname=conv_bn_pool_2 --gamma_y=7e5 --nname_encoder_x2topic_h=conv_bn_pool_2 --nname_encoder_sandwich_x2h4zd=conv_bn_pool_2
```

### hduva is domain-unsupervised, so it also works with a single training domain
```shell
python main_out.py --tr_d 0 --te_d 3 4 --bs=2 --task=mnistcolor10 --model=hduva --nname=conv_bn_pool_2 --gamma_y=7e5 --nname_encoder_x2topic_h=conv_bn_pool_2 --nname_encoder_sandwich_x2h4zd=conv_bn_pool_2
```
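
### set the number of warm-up epochs
The same single-domain command with an explicit warm-up period (assuming the flag spelling `--warmup`; the value 5 is only illustrative):
```shell
python main_out.py --tr_d 0 --te_d 3 4 --bs=2 --task=mnistcolor10 --model=hduva --nname=conv_bn_pool_2 --gamma_y=7e5 --nname_encoder_x2topic_h=conv_bn_pool_2 --nname_encoder_sandwich_x2h4zd=conv_bn_pool_2 --warmup=5
```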



Please cite our paper if you find it useful!
```text
4 changes: 4 additions & 0 deletions docs/docIRM.md
@@ -26,4 +26,8 @@ where $\lambda$ is a hyperparameter that controls the trade-off between the empi
In practice, one can simply split a mini-batch into two subsets, indexed by $i$ and $j$; the inner product between the gradients computed on subset $i$ and on subset $j$ then forms an unbiased estimate of the squared L2 norm of the gradient.
In detail: the squared gradient norm is estimated via the inner product between $\nabla_{w|w=1} \ell(w \circ \Phi(X^{(d, i)}), Y^{(d, i)})$ and $\nabla_{w|w=1} \ell(w \circ \Phi(X^{(d, j)}), Y^{(d, j)})$, both of dimension dim(Grad). For more details, see Section 3.2 and Appendix D of Arjovsky et al., "Invariant Risk Minimization."
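
A minimal PyTorch sketch of this two-subset estimator (a self-contained illustration; the helper name `irm_penalty` and the use of a scalar dummy classifier follow the paper and are not DomainLab's internal API):

```python
import torch
from torch.nn import functional as F


def irm_penalty(logits, y):
    """unbiased estimate of the squared gradient norm w.r.t. the
    scalar dummy classifier w, evaluated at w=1"""
    w = torch.tensor(1.0, requires_grad=True)
    # subsets i (even indices) and j (odd indices) of the mini-batch
    loss_i = F.cross_entropy(logits[::2] * w, y[::2])
    loss_j = F.cross_entropy(logits[1::2] * w, y[1::2])
    grad_i = torch.autograd.grad(loss_i, [w], create_graph=True)[0]
    grad_j = torch.autograd.grad(loss_j, [w], create_graph=True)[0]
    # inner product of the two gradient estimates
    return (grad_i * grad_j).sum()


logits = torch.randn(8, 3, requires_grad=True)  # stand-in for Phi(X)
y = torch.randint(0, 3, (8,))
print(irm_penalty(logits, y))
```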

# Examples
```shell
python main_out.py --te_d=0 --task=mnistcolor10 --model=erm --trainer=irm --nname=conv_bn_pool_2
```
10 changes: 0 additions & 10 deletions docs/doc_examples.md
@@ -26,16 +26,6 @@ python main_out.py --te_d 0 1 --tr_d 3 5 --task=mnistcolor10 --debug --bs=2 --mo
python main_out.py --te_d=0 --task=mnistcolor10 --keep_model --model=diva --nname=conv_bn_pool_2 --nname_dom=conv_bn_pool_2 --gamma_y=10e5 --gamma_d=1e5 --gen
```

### use hduva on color mnist, train on 2 domains
```shell
python main_out.py --tr_d 0 1 2 --te_d 3 --bs=2 --task=mnistcolor10 --model=hduva --nname=conv_bn_pool_2 --gamma_y=7e5 --nname_encoder_x2topic_h=conv_bn_pool_2 --nname_encoder_sandwich_x2h4zd=conv_bn_pool_2
```

### hduva is domain-unsupervised, so it works also with a single domain
```shell
python main_out.py --tr_d 0 --te_d 3 4 --bs=2 --task=mnistcolor10 --model=hduva --nname=conv_bn_pool_2 --gamma_y=7e5 --nname_encoder_x2topic_h=conv_bn_pool_2 --nname_encoder_sandwich_x2h4zd=conv_bn_pool_2
```


## Larger images:

77 changes: 77 additions & 0 deletions domainlab/algos/trainers/train_causIRL.py
@@ -0,0 +1,77 @@
"""
Alex, Xudong
"""
import numpy as np
import torch
from domainlab.algos.trainers.train_basic import TrainerBasic


class TrainerCausalIRL(TrainerBasic):
    """
    causal matching: augment the task loss with an MMD penalty between
    the features of two random splits of each mini-batch
    """
    def my_cdist(self, x1, x2):
        """
        pairwise squared Euclidean distances, used by the Gaussian kernel
        """
        # squared norms along the last (feature) dimension
        x1_norm = x1.pow(2).sum(dim=-1, keepdim=True)
        x2_norm = x2.pow(2).sum(dim=-1, keepdim=True)
        # x2_norm is [batchsize2, 1]
        # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b, computed for all pairs:
        # torch.addmm(input, mat1, mat2, alpha) = input + alpha * mat1 @ mat2
        # with mat1 @ mat2 = X1[batchsize1, dimfeat] @ X2^T[dimfeat, batchsize2];
        # x2_norm.transpose(-2, -1) is a row vector (broadcast over rows),
        # x1_norm is a column vector (broadcast over columns)
        res = torch.addmm(x2_norm.transpose(-2, -1),
                          x1,
                          x2.transpose(-2, -1), alpha=-2).add_(x1_norm)
        return res.clamp_min_(1e-30)

    def gaussian_kernel(self, x, y):
        """
        multi-bandwidth Gaussian (RBF) kernel for MMD
        """
        gamma = [0.001, 0.01, 0.1, 1, 10, 100, 1000]
        dist = self.my_cdist(x, y)
        tensor = torch.zeros_like(dist)
        # sum RBF kernels exp(-g * dist) over the list of bandwidths
        for g in gamma:
            tensor.add_(torch.exp(dist.mul(-g)))
        return tensor

    def mmd(self, x, y):
        """
        (biased) estimate of the squared maximum mean discrepancy:
        MMD^2 = E[k(x, x')] + E[k(y, y')] - 2 E[k(x, y)]
        """
        kxx = self.gaussian_kernel(x, x).mean()
        kyy = self.gaussian_kernel(y, y).mean()
        kxy = self.gaussian_kernel(x, y).mean()
        return kxx + kyy - 2 * kxy

    def tr_batch(self, tensor_x, tensor_y, tensor_d, others, ind_batch, epoch):
        """
        optimize the neural network one step on a mini-batch of data
        """
        self.before_batch(epoch, ind_batch)
        tensor_x, tensor_y, tensor_d = (
            tensor_x.to(self.device),
            tensor_y.to(self.device),
            tensor_d.to(self.device),
        )
        self.optimizer.zero_grad()

        features = self.get_model().extract_semantic_feat(tensor_x)

        # split the mini-batch at a random position into two groups
        pos_batch_break = np.random.randint(0, tensor_x.shape[0])
        first = features[:pos_batch_break]
        second = features[pos_batch_break:]
        if len(first) > 1 and len(second) > 1:
            # penalize distribution mismatch between the two feature groups
            penalty = torch.nan_to_num(self.mmd(first, second))
        else:
            penalty = torch.tensor(0.0)
        loss = self.cal_loss(tensor_x, tensor_y, tensor_d, others)
        loss = loss + penalty
        loss.backward()
        self.optimizer.step()
        self.after_batch(epoch, ind_batch)
        self.counter_batch += 1
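
The new trainer plugs into the trainer chain below and can be selected from the command line like any other trainer; a minimal sketch, mirroring the flag combination used by the end-to-end test at the bottom of this commit:

```shell
python main_out.py --te_d=0 --task=mnistcolor10 --model=erm --trainer=causalirl --nname=conv_bn_pool_2
```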
5 changes: 4 additions & 1 deletion domainlab/algos/trainers/zoo_trainer.py
@@ -4,11 +4,13 @@
from domainlab.algos.trainers.train_basic import TrainerBasic
from domainlab.algos.trainers.train_ema import TrainerMA
from domainlab.algos.trainers.train_dial import TrainerDIAL
from domainlab.algos.trainers.train_hyper_scheduler \
    import TrainerHyperScheduler
from domainlab.algos.trainers.train_matchdg import TrainerMatchDG
from domainlab.algos.trainers.train_mldg import TrainerMLDG
from domainlab.algos.trainers.train_fishr import TrainerFishr
from domainlab.algos.trainers.train_irm import TrainerIRM
from domainlab.algos.trainers.train_causIRL import TrainerCausalIRL


class TrainerChainNodeGetter(object):
@@ -54,6 +56,7 @@ def __call__(self, lst_candidates=None, default=None, lst_excludes=None):
        chain = TrainerFishr(chain)
        chain = TrainerIRM(chain)
        chain = TrainerHyperScheduler(chain)
        chain = TrainerCausalIRL(chain)
        node = chain.handle(self.request)
        head = node
        while self._list_str_trainer:
24 changes: 17 additions & 7 deletions scripts/ci_run_examples.sh
@@ -3,14 +3,24 @@ set -e  # exit upon first error
# >> append content
# > erase original content

# echo "#!/bin/bash -x -v" > sh_temp_example.sh
sed -n '/```shell/,/```/ p' docs/doc_examples.md | sed '/^```/ d' >> ./sh_temp_example.sh
split -l 5 sh_temp_example.sh sh_example_split
for file in sh_example_split*;
do (echo "#!/bin/bash -x -v" > "$file"_exe && cat "$file" >> "$file"_exe && bash -x -v "$file"_exe && rm -r zoutput);

files=("docs/docDIAL.md" "docs/docIRM.md" "docs/doc_examples.md" "docs/docHDUVA.md")

for file in "${files[@]}"
do
    echo "Processing $file"
    # no need to remove sh_temp_algo.sh: the next line overwrites it each time
    echo "#!/bin/bash -x -v" > sh_temp_algo.sh
    # extract the fenced shell blocks and strip the ``` markers;
    # use >> to append so the #!/bin/bash -x -v header above is kept
    sed -n '/```shell/,/```/ p' "$file" | sed '/^```/ d' >> ./sh_temp_algo.sh
    cat sh_temp_algo.sh
    bash -x -v -e sh_temp_algo.sh
    echo "finished with $file"
done
# bash -x -v -e sh_temp_example.sh
echo "general examples done"



echo "#!/bin/bash -x -v" > sh_temp_mnist.sh
sed -n '/```shell/,/```/ p' docs/doc_MNIST_classification.md | sed '/^```/ d' >> ./sh_temp_mnist.sh
13 changes: 13 additions & 0 deletions tests/test_causal_irl.py
@@ -0,0 +1,13 @@
"""
end-end test
"""
from tests.utils_test import utils_test_algo


def test_causal_irl():
"""
causal irl
"""
args = "--te_d 0 --tr_d 3 7 --bs=32 --debug --task=mnistcolor10 \
--model=erm --nname=conv_bn_pool_2 --trainer=causalirl"
utils_test_algo(args)
