Skip to content

Commit

Permalink
Merge pull request #869 from marrlab/erm_hyper_init
Browse files Browse the repository at this point in the history
matteo change erm
  • Loading branch information
smilesun authored Sep 16, 2024
2 parents 7494a61 + 47e0ce2 commit 9da2dee
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 3 deletions.
5 changes: 3 additions & 2 deletions domainlab/algos/trainers/train_irm.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,5 +66,6 @@ def _cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others=None):
loss_2 = F.cross_entropy(phi[1::2] * dummy_w_scale, y[1::2])
grad_1 = autograd.grad(loss_1, [dummy_w_scale], create_graph=True)[0]
grad_2 = autograd.grad(loss_2, [dummy_w_scale], create_graph=True)[0]
loss_irm = torch.sum(grad_1 * grad_2)
return [loss_irm], [self.aconf.gamma_reg]
loss_irm_scalar = torch.sum(grad_1 * grad_2) # scalar
loss_irm_tensor = loss_irm_scalar.expand(tensor_x.shape[0])
return [loss_irm_tensor], [self.aconf.gamma_reg]
27 changes: 26 additions & 1 deletion domainlab/models/model_erm.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
except:
backpack = None


def mk_erm(parent_class=AModelClassif, **kwargs):
"""
Instantiate a Deepall (ERM) model
Expand Down Expand Up @@ -53,4 +52,30 @@ def convert4backpack(self):
"""
self._net_invar_feat = extend(self._net_invar_feat, use_converter=True)
self.net_classifier = extend(self.net_classifier, use_converter=True)

def hyper_update(self, epoch, fun_scheduler): # pylint: disable=unused-argument
"""
Method necessary to combine with hyperparameter scheduler
:param epoch:
:param fun_scheduler:
"""

def hyper_init(self, functor_scheduler, trainer=None):
"""
initiate a scheduler object via class name and things inside this model
:param functor_scheduler: the class name of the scheduler
"""
return functor_scheduler(
trainer=trainer
)

@property
def list_str_multiplier_na(self):
"""
list of multipliers which match the order in cal_reg_loss
"""
return []

return ModelERM
11 changes: 11 additions & 0 deletions tests/test_irm.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,17 @@ def test_irm():
utils_test_algo(args)


def test_irm_scheduler():
"""
train with Invariant Risk Minimization
"""
args = "--te_d=caltech --task=mini_vlcs --debug --bs=2 --model=erm \
--trainer=hyperscheduler_irm --nname=alexnet"
utils_test_algo(args)




def test_irm_mnist():
"""
train with Invariant Risk Minimization
Expand Down

0 comments on commit 9da2dee

Please sign in to comment.