diff --git a/domainlab/algos/trainers/train_fbopt.py b/domainlab/algos/trainers/train_fbopt.py
index d4ea27af6..8bd2c987b 100644
--- a/domainlab/algos/trainers/train_fbopt.py
+++ b/domainlab/algos/trainers/train_fbopt.py
@@ -96,7 +96,9 @@ def eval_r_loss(self):
         for _, (tensor_x, vec_y, vec_d, *_) in enumerate(self.loader_tr_no_drop):
             tensor_x, vec_y, vec_d = \
                 tensor_x.to(self.device), vec_y.to(self.device), vec_d.to(self.device)
-            b_reg_loss = temp_model.cal_reg_loss(tensor_x, vec_y, vec_d).sum()
+            tuple_reg_loss = temp_model.cal_reg_loss(tensor_x, vec_y, vec_d)
+            b_reg_loss = tuple_reg_loss[0][0]  # FIXME: this only works when scalar multiplier
+            b_reg_loss = b_reg_loss.sum().item()
             b_task_loss = temp_model.cal_task_loss(tensor_x, vec_y).sum()
             # sum will kill the dimension of the mini batch
             epo_reg_loss += b_reg_loss
diff --git a/domainlab/models/a_model.py b/domainlab/models/a_model.py
index 151809d10..c1a52d4d3 100644
--- a/domainlab/models/a_model.py
+++ b/domainlab/models/a_model.py
@@ -24,8 +24,18 @@ def cal_loss(self, tensor_x, tensor_y, tensor_d=None, others=None):
         """
         calculate the loss
         """
-        return self.cal_task_loss(tensor_x, tensor_y) + \
-            self.cal_reg_loss(tensor_x, tensor_y, tensor_d, others)
+        list_loss, list_multiplier = self.cal_reg_loss(tensor_x, tensor_y, tensor_d, others)
+        loss_reg = self.inner_product(list_loss, list_multiplier)
+        return self.cal_task_loss(tensor_x, tensor_y) + loss_reg
+
+
+    def inner_product(self, list_loss_scalar, list_multiplier):
+        """
+        compute inner product between list of scalar loss and multiplier
+        """
+        list_tuple = zip(list_loss_scalar, list_multiplier)
+        rst = [mtuple[0]*mtuple[1] for mtuple in list_tuple]
+        return sum(rst)  # FIXME: is "sum" safe to pytorch?

     @abc.abstractmethod
     def cal_task_loss(self, tensor_x, tensor_y):
diff --git a/domainlab/models/a_model_classif.py b/domainlab/models/a_model_classif.py
index b29e9fea1..2e3d20b13 100644
--- a/domainlab/models/a_model_classif.py
+++ b/domainlab/models/a_model_classif.py
@@ -190,4 +190,7 @@ def cal_loss_gen_adv(self, x_natural, x_adv, vec_y):
         return loss_adv_gen + loss_adv_gen_task.sum()

     def cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others=None):
-        return 0
+        """
+        for ERM to adapt to the interface of other regularized learners
+        """
+        return [0], [0]
diff --git a/domainlab/models/model_dann.py b/domainlab/models/model_dann.py
index e48a6b5f4..2dd14e381 100644
--- a/domainlab/models/model_dann.py
+++ b/domainlab/models/model_dann.py
@@ -84,5 +84,5 @@ def cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others=None):
                 AutoGradFunReverseMultiply.apply(feat, self.alpha))
             _, d_target = tensor_d.max(dim=1)
             lc_d = F.cross_entropy(logit_d, d_target, reduction="none")
-            return self.alpha*lc_d
+            return [lc_d], [self.alpha]
     return ModelDAN
diff --git a/domainlab/models/model_diva.py b/domainlab/models/model_diva.py
index 635d06839..906233678 100644
--- a/domainlab/models/model_diva.py
+++ b/domainlab/models/model_diva.py
@@ -93,10 +93,6 @@ def cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others=None):
             _, d_target = tensor_d.max(dim=1)
             lc_d = F.cross_entropy(logit_d, d_target, reduction="none")

-            loss_reg = loss_recon_x \
-                - self.beta_d * zd_p_minus_zd_q \
-                - self.beta_x * zx_p_minus_zx_q \
-                - self.beta_y * zy_p_minus_zy_q \
-                + self.gamma_d * lc_d
-            return loss_reg
+            return [loss_recon_x, zd_p_minus_zd_q, zx_p_minus_zx_q, zy_p_minus_zy_q, lc_d], \
+                [1.0, -self.beta_d, -self.beta_x, -self.beta_y, -self.gamma_d]
     return ModelDIVA
diff --git a/domainlab/models/model_hduva.py b/domainlab/models/model_hduva.py
index ba5074b77..3c7b3ccc7 100644
--- a/domainlab/models/model_hduva.py
+++ b/domainlab/models/model_hduva.py
@@ -121,12 +121,8 @@ def cal_reg_loss(self, tensor_x, tensor_y, tensor_d=None, others=None):
         # reconstruction
         z_concat = self.decoder.concat_ytdx(zy_q, topic_q, zd_q, zx_q)
         loss_recon_x, _, _ = self.decoder(z_concat, tensor_x)
-        batch_loss = loss_recon_x \
-            - self.beta_x * zx_p_minus_q \
-            - self.beta_y * zy_p_minus_zy_q \
-            - self.beta_d * zd_p_minus_q \
-            - self.beta_t * topic_p_minus_q
-        return batch_loss
+        return [loss_recon_x, zx_p_minus_q, zy_p_minus_zy_q, zd_p_minus_q, topic_p_minus_q], \
+            [1.0, -self.beta_x, -self.beta_y, -self.beta_d, -self.beta_t]

     def extract_semantic_features(self, tensor_x):
         """
diff --git a/domainlab/models/model_jigen.py b/domainlab/models/model_jigen.py
index 27c12bbc8..42dba6c92 100644
--- a/domainlab/models/model_jigen.py
+++ b/domainlab/models/model_jigen.py
@@ -22,7 +22,7 @@ def mk_jigen(parent_class=AModelClassif):
     For more details, see:
     Carlucci, Fabio M., et al. "Domain generalization by solving jigsaw puzzles."
     Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2019.
-
+
     Args:
         parent_class (AModel, optional):
             Class object determining the task type. Defaults to AModelClassif.
@@ -30,7 +30,7 @@ def mk_jigen(parent_class=AModelClassif):
     Returns:
         ModelJiGen: model inheriting from parent class

-    Input Parameters:
+    Input Parameters:
        list_str_y: list of labels,
        list_str_d: list of domains,
        net_encoder: neural network (input: training data, standard and shuffled),
@@ -83,5 +83,5 @@ def cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others=None):
             batch_target_scalar = vec_perm_ind
             loss_perm = F.cross_entropy(
                 logits_which_permutation, batch_target_scalar, reduction="none")
-            return self.alpha*loss_perm
+            return [loss_perm], [self.alpha]
     return ModelJiGen
diff --git a/domainlab/models/wrapper_matchdg.py b/domainlab/models/wrapper_matchdg.py
index dd305890e..07805980b 100644
--- a/domainlab/models/wrapper_matchdg.py
+++ b/domainlab/models/wrapper_matchdg.py
@@ -30,7 +30,11 @@ def cal_loss(self, tensor_x, tensor_y, tensor_d=None, others=None):
         return self.net.cal_loss(tensor_x, tensor_y, tensor_d)

     def cal_reg_loss(self, tensor_x, tensor_y, tensor_d=None, others=None):
-        return self.net.cal_loss(tensor_x, tensor_y, tensor_d)  # @FIXME: this is wrong
+        """
+        abstract method must be in place, but it will not be called since
+        cal_loss function is overriden
+        """
+        raise NotImplementedError

     def forward(self, tensor_x):
         """
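
The changes above converge on one contract: `cal_reg_loss` returns a pair of equal-length lists, per-sample regularization loss tensors and their multipliers, and `AModel.cal_loss` combines them with `inner_product` before adding the task loss. Below is a minimal, self-contained sketch of that contract for reference; it is not part of the diff, and `ToyModel`, its linear head, and the toy tensor shapes are hypothetical stand-ins, while the `cal_reg_loss` / `inner_product` / `cal_loss` wiring mirrors the patched `AModel` and the scalar-multiplier models (`ModelDAN`, `ModelJiGen`).

# Sketch only, not part of the diff: ToyModel and the toy shapes are hypothetical.
import torch
from torch import nn
from torch.nn import functional as F


class ToyModel(nn.Module):
    """hypothetical classifier following the (list_loss, list_multiplier) contract"""

    def __init__(self, dim_feat=192, num_cls=3, alpha=0.1):
        super().__init__()
        self.net = nn.Linear(dim_feat, num_cls)
        self.alpha = alpha

    def cal_task_loss(self, tensor_x, tensor_y):
        logits = self.net(tensor_x.flatten(1))
        return F.cross_entropy(logits, tensor_y.argmax(dim=1), reduction="none")

    def cal_reg_loss(self, tensor_x, tensor_y, tensor_d=None, others=None):
        # one regularizer weighted by a scalar multiplier, as in ModelDAN / ModelJiGen
        reg = tensor_x.flatten(1).pow(2).mean(dim=1)
        return [reg], [self.alpha]

    def inner_product(self, list_loss_scalar, list_multiplier):
        # same weighted sum as AModel.inner_product in the diff above
        return sum(loss * mult for loss, mult in zip(list_loss_scalar, list_multiplier))

    def cal_loss(self, tensor_x, tensor_y, tensor_d=None, others=None):
        list_loss, list_multiplier = self.cal_reg_loss(tensor_x, tensor_y, tensor_d, others)
        return self.cal_task_loss(tensor_x, tensor_y) + self.inner_product(list_loss, list_multiplier)


tensor_x = torch.randn(4, 3, 8, 8)                                        # batch of 4 toy images
tensor_y = F.one_hot(torch.tensor([0, 1, 2, 1]), num_classes=3).float()   # one-hot labels
print(ToyModel().cal_loss(tensor_x, tensor_y).shape)                      # torch.Size([4]): per-sample loss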