Showing 44 changed files with 582 additions and 1,700 deletions.
@@ -1,7 +1,13 @@
 .ropeproject
-./zdpath
-./zoutput
+/zdpath
+/zoutput
 tests/__pycache__/
 *.pyc
 .vscode/
 domainlab/zdata/pacs
+/data/
+/.snakemake/
+/dist
+/domainlab.egg-info
+/runs
+/slurm_errors.txt
@@ -0,0 +1,70 @@
""" | ||
use random start to generate adversarial images | ||
""" | ||
import torch | ||
from torch import autograd | ||
from torch.nn import functional as F | ||
from domainlab.algos.trainers.train_basic import TrainerBasic | ||
|
||
|
||
class TrainerIRM(TrainerBasic): | ||
""" | ||
IRMv1 split a minibatch into half, and use an unbiased estimate of the | ||
squared gradient norm via inner product | ||
$$\\delta_{w|w=1} \\ell(w\\dot \\Phi(X^{e, i}), Y^{e, i})$$ | ||
of dimension dim(Grad) | ||
with | ||
$$\\delta_{w|w=1} \\ell(w\\dot \\Phi(X^{e, j}), Y^{e, j})$$ | ||
of dimension dim(Grad) | ||
For more details, see section 3.2 and Appendix D of : | ||
Arjovsky et al., “Invariant Risk Minimization.” | ||
""" | ||
def tr_epoch(self, epoch): | ||
list_loaders = list(self.dict_loader_tr.values()) | ||
loaders_zip = zip(*list_loaders) | ||
self.model.train() | ||
self.epo_loss_tr = 0 | ||
|
||
for ind_batch, tuple_data_domains_batch in enumerate(loaders_zip): | ||
self.optimizer.zero_grad() | ||
list_domain_loss_erm = [] | ||
list_domain_reg = [] | ||
for batch_domain_e in tuple_data_domains_batch: | ||
tensor_x, tensor_y, tensor_d, *others = batch_domain_e | ||
tensor_x, tensor_y, tensor_d = \ | ||
tensor_x.to(self.device), tensor_y.to(self.device), \ | ||
tensor_d.to(self.device) | ||
list_domain_loss_erm.append( | ||
self.model.cal_task_loss(tensor_x, tensor_y)) | ||
list_1ele_loss_irm, _ = \ | ||
self._cal_reg_loss(tensor_x, tensor_y, tensor_d, others) | ||
list_domain_reg += list_1ele_loss_irm | ||
loss = torch.sum(torch.stack(list_domain_loss_erm)) + \ | ||
self.aconf.gamma_reg * torch.sum(torch.stack(list_domain_reg)) | ||
loss.backward() | ||
self.optimizer.step() | ||
self.epo_loss_tr += loss.detach().item() | ||
self.after_batch(epoch, ind_batch) | ||
|
||
flag_stop = self.observer.update(epoch) # notify observer | ||
return flag_stop | ||
|
||
def _cal_phi(self, tensor_x): | ||
logits = self.model.cal_logit_y(tensor_x) | ||
return logits | ||
|
||
def _cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others=None): | ||
""" | ||
Let trainer behave like a model, so that other trainer could use it | ||
""" | ||
_ = tensor_d | ||
_ = others | ||
y = tensor_y | ||
phi = self._cal_phi(tensor_x) | ||
dummy_w_scale = torch.tensor(1.).to(tensor_x.device).requires_grad_() | ||
loss_1 = F.cross_entropy(phi[::2] * dummy_w_scale, y[::2]) | ||
loss_2 = F.cross_entropy(phi[1::2] * dummy_w_scale, y[1::2]) | ||
grad_1 = autograd.grad(loss_1, [dummy_w_scale], create_graph=True)[0] | ||
grad_2 = autograd.grad(loss_2, [dummy_w_scale], create_graph=True)[0] | ||
loss_irm = torch.sum(grad_1 * grad_2) | ||
return [loss_irm], [self.aconf.gamma_reg] |
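To make the penalty easier to inspect outside of domainlab, here is a minimal, self-contained sketch of the same IRMv1 trick. The function name, toy tensors, and class count are hypothetical and not part of the commit; only the even/odd split, the dummy scale w=1, and the inner product of the two gradients mirror _cal_reg_loss above.

# standalone illustration of the IRMv1 penalty (hypothetical names, not from the commit)
import torch
from torch import autograd
from torch.nn import functional as F


def irm_penalty(logits, targets):
    """Unbiased estimate of the squared gradient norm of the risk w.r.t. a
    dummy classifier scale w (evaluated at w=1), computed as the inner
    product of the gradients from the two halves of the minibatch."""
    dummy_w = torch.tensor(1.0, device=logits.device, requires_grad=True)
    loss_a = F.cross_entropy(logits[::2] * dummy_w, targets[::2])
    loss_b = F.cross_entropy(logits[1::2] * dummy_w, targets[1::2])
    grad_a = autograd.grad(loss_a, [dummy_w], create_graph=True)[0]
    grad_b = autograd.grad(loss_b, [dummy_w], create_graph=True)[0]
    return torch.sum(grad_a * grad_b)


# toy usage: random logits for a 7-class problem, batch of 16
logits = torch.randn(16, 7, requires_grad=True)
targets = torch.randint(0, 7, (16,))
penalty = irm_penalty(logits, targets)
penalty.backward()  # create_graph=True lets the penalty backpropagate into logits
print(penalty.item())

Taking the inner product of two independent gradient estimates, rather than squaring a single gradient, is what makes the estimator of the squared gradient norm unbiased (E[g_i * g_j] equals the squared norm of the expected gradient when the two halves are independent); this is the construction the docstring points to in Section 3.2 and Appendix D of the IRM paper.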