diff --git a/domainlab/algos/trainers/train_irm.py b/domainlab/algos/trainers/train_irm.py
index 0da2d0bde..90b16396e 100644
--- a/domainlab/algos/trainers/train_irm.py
+++ b/domainlab/algos/trainers/train_irm.py
@@ -66,5 +66,6 @@ def _cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others=None):
         loss_2 = F.cross_entropy(phi[1::2] * dummy_w_scale, y[1::2])
         grad_1 = autograd.grad(loss_1, [dummy_w_scale], create_graph=True)[0]
         grad_2 = autograd.grad(loss_2, [dummy_w_scale], create_graph=True)[0]
-        loss_irm = torch.sum(grad_1 * grad_2)
-        return [loss_irm], [self.aconf.gamma_reg]
+        loss_irm_scalar = torch.sum(grad_1 * grad_2)  # scalar
+        loss_irm_tensor = loss_irm_scalar.expand(tensor_x.shape[0])
+        return [loss_irm_tensor], [self.aconf.gamma_reg]
diff --git a/domainlab/models/model_erm.py b/domainlab/models/model_erm.py
index 6ee9c23f9..4ccec7a50 100644
--- a/domainlab/models/model_erm.py
+++ b/domainlab/models/model_erm.py
@@ -10,7 +10,6 @@
 except:
     backpack = None
 
-
 def mk_erm(parent_class=AModelClassif, **kwargs):
     """
     Instantiate a Deepall (ERM) model
@@ -53,4 +52,30 @@ def convert4backpack(self):
             """
             self._net_invar_feat = extend(self._net_invar_feat, use_converter=True)
             self.net_classifier = extend(self.net_classifier, use_converter=True)
+
+        def hyper_update(self, epoch, fun_scheduler):  # pylint: disable=unused-argument
+            """
+            Method necessary to combine with a hyperparameter scheduler
+
+            :param epoch: current epoch number
+            :param fun_scheduler: scheduler functor updating the hyperparameters
+            """
+
+        def hyper_init(self, functor_scheduler, trainer=None):
+            """
+            Initiate a scheduler object via its class name and the trainer
+
+            :param functor_scheduler: the class name of the scheduler
+            """
+            return functor_scheduler(
+                trainer=trainer
+            )
+
+        @property
+        def list_str_multiplier_na(self):
+            """
+            list of multiplier names matching the order in cal_reg_loss
+            """
+            return []
+
     return ModelERM
diff --git a/tests/test_irm.py b/tests/test_irm.py
index 5ed8b4ceb..235b9e4ce 100644
--- a/tests/test_irm.py
+++ b/tests/test_irm.py
@@ -13,6 +13,17 @@ def test_irm():
     utils_test_algo(args)
 
 
+def test_irm_scheduler():
+    """
+    train with Invariant Risk Minimization and a hyperparameter scheduler
+    """
+    args = "--te_d=caltech --task=mini_vlcs --debug --bs=2 --model=erm \
+        --trainer=hyperscheduler_irm --nname=alexnet"
+    utils_test_algo(args)
+
+
+
+
 def test_irm_mnist():
     """
     train with Invariant Risk Minimization
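
Note for reviewers on the first hunk: the IRM penalty (the product of the gradients of the two half-batch risks w.r.t. a dummy scale weight) is a single scalar, while a trainer that schedules regularization multipliers, such as the `hyperscheduler_irm` trainer exercised by the new test, presumably expects one reg-loss value per sample; expanding the scalar to batch size makes the penalty look like any other per-sample reg term. The following standalone sketch reproduces just that computation outside DomainLab; the toy shapes, the random seed, `batch`, and the constant `gamma_reg` are illustrative assumptions, not part of the patch.

import torch
import torch.nn.functional as F
from torch import autograd

torch.manual_seed(0)
batch = 6
phi = torch.randn(batch, 3, requires_grad=True)   # toy logits, as if from a network
y = torch.randint(0, 3, (batch,))                 # toy labels
dummy_w_scale = torch.tensor(1.0, requires_grad=True)

# IRMv1-style penalty: split the batch in two, take the gradient of each
# half's risk w.r.t. the dummy scale, and multiply the two gradients.
loss_1 = F.cross_entropy(phi[0::2] * dummy_w_scale, y[0::2])
loss_2 = F.cross_entropy(phi[1::2] * dummy_w_scale, y[1::2])
grad_1 = autograd.grad(loss_1, [dummy_w_scale], create_graph=True)[0]
grad_2 = autograd.grad(loss_2, [dummy_w_scale], create_graph=True)[0]

loss_irm_scalar = torch.sum(grad_1 * grad_2)      # 0-dim tensor
loss_irm_tensor = loss_irm_scalar.expand(batch)   # shape (6,), broadcast view

gamma_reg = 1.0  # stand-in for the scheduled multiplier
total_reg = (gamma_reg * loss_irm_tensor).mean()  # same value as the scalar
total_reg.backward()                              # gradients reach phi via expand
print(loss_irm_tensor.shape, total_reg.item(), phi.grad.shape)

Since `expand` creates a broadcast view rather than copying data, the per-sample tensor keeps the scalar's autograd graph, so averaging it and backpropagating recovers exactly the scalar penalty's gradients.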