Merge pull request #368 from marrlab/fbopt_reg
the correct way to calculate regularization loss
smilesun authored Sep 14, 2023
2 parents daac36d + c03e639 commit b04aa58
Showing 8 changed files with 32 additions and 21 deletions.
4 changes: 3 additions & 1 deletion domainlab/algos/trainers/train_fbopt.py
@@ -96,7 +96,9 @@ def eval_r_loss(self):
for _, (tensor_x, vec_y, vec_d, *_) in enumerate(self.loader_tr_no_drop):
tensor_x, vec_y, vec_d = \
tensor_x.to(self.device), vec_y.to(self.device), vec_d.to(self.device)
b_reg_loss = temp_model.cal_reg_loss(tensor_x, vec_y, vec_d).sum()
tuple_reg_loss = temp_model.cal_reg_loss(tensor_x, vec_y, vec_d)
b_reg_loss = tuple_reg_loss[0][0]  # FIXME: this only works when the multiplier is a scalar
b_reg_loss = b_reg_loss.sum().item()
b_task_loss = temp_model.cal_task_loss(tensor_x, vec_y).sum()
# sum collapses the mini-batch dimension
epo_reg_loss += b_reg_loss
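For orientation, a minimal sketch of the tuple this trainer now receives from cal_reg_loss; the variable names below are illustrative, not from the repository:

    # tuple_reg_loss is (list_of_per_sample_loss_tensors, list_of_multipliers)
    list_loss, list_mul = tuple_reg_loss
    first_term = list_loss[0]  # what the hunk above reads as tuple_reg_loss[0][0]
    # if the weighted total were wanted instead, it could be accumulated as
    # sum(mul * loss for loss, mul in zip(list_loss, list_mul)).sum().item()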
14 changes: 12 additions & 2 deletions domainlab/models/a_model.py
@@ -24,8 +24,18 @@ def cal_loss(self, tensor_x, tensor_y, tensor_d=None, others=None):
"""
calculate the loss
"""
return self.cal_task_loss(tensor_x, tensor_y) + \
self.cal_reg_loss(tensor_x, tensor_y, tensor_d, others)
list_loss, list_multiplier = self.cal_reg_loss(tensor_x, tensor_y, tensor_d, others)
loss_reg = self.inner_product(list_loss, list_multiplier)
return self.cal_task_loss(tensor_x, tensor_y) + loss_reg


def inner_product(self, list_loss_scalar, list_multiplier):
"""
compute the inner product between the list of scalar losses and the list of multipliers
"""
list_tuple = zip(list_loss_scalar, list_multiplier)
rst = [mtuple[0]*mtuple[1] for mtuple in list_tuple]
return sum(rst)  # FIXME: is Python's built-in "sum" safe for PyTorch tensors?

@abc.abstractmethod
def cal_task_loss(self, tensor_x, tensor_y):
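A minimal, self-contained sketch of what the new two-list contract and inner_product amount to numerically (the tensors and multipliers below are made up for illustration):

    import torch

    # what a regularized model's cal_reg_loss now returns:
    list_loss = [torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0])]  # per-sample loss terms
    list_multiplier = [0.5, 2.0]                                      # one weight per term

    # the reduction performed by inner_product above:
    loss_reg = sum(l * m for l, m in zip(list_loss, list_multiplier))
    print(loss_reg)  # tensor([6.5, 9.0]) == 0.5 * [1, 2] + 2.0 * [3, 4]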
5 changes: 4 additions & 1 deletion domainlab/models/a_model_classif.py
@@ -190,4 +190,7 @@ def cal_loss_gen_adv(self, x_natural, x_adv, vec_y):
return loss_adv_gen + loss_adv_gen_task.sum()

def cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others=None):
return 0
"""
for ERM to conform to the interface of the other, regularized learners
"""
return [0], [0]
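As a quick sanity check, the ERM return value above degenerates cleanly under the base-class reduction (a sketch, reusing the inner_product logic shown earlier):

    list_loss, list_multiplier = [0], [0]  # what ERM's cal_reg_loss returns
    loss_reg = sum(l * m for l, m in zip(list_loss, list_multiplier))  # 0
    # so cal_loss falls back to the pure task loss for the ERM baseline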
2 changes: 1 addition & 1 deletion domainlab/models/model_dann.py
@@ -84,5 +84,5 @@ def cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others=None):
AutoGradFunReverseMultiply.apply(feat, self.alpha))
_, d_target = tensor_d.max(dim=1)
lc_d = F.cross_entropy(logit_d, d_target, reduction="none")
return self.alpha*lc_d
return [lc_d], [self.alpha]
return ModelDAN
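Read through the base-class inner_product, the new return value is equivalent to the old one (a sketch; lc_d and self.alpha as in the hunk above):

    list_loss, list_multiplier = [lc_d], [self.alpha]
    loss_reg = sum(l * m for l, m in zip(list_loss, list_multiplier))
    # == self.alpha * lc_d, i.e. the value returned before this change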
8 changes: 2 additions & 6 deletions domainlab/models/model_diva.py
@@ -93,10 +93,6 @@ def cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others=None):
_, d_target = tensor_d.max(dim=1)
lc_d = F.cross_entropy(logit_d, d_target, reduction="none")

loss_reg = loss_recon_x \
- self.beta_d * zd_p_minus_zd_q \
- self.beta_x * zx_p_minus_zx_q \
- self.beta_y * zy_p_minus_zy_q \
+ self.gamma_d * lc_d
return loss_reg
return [loss_recon_x, zd_p_minus_zd_q, zx_p_minus_zx_q, zy_p_minus_zy_q, lc_d], \
[1.0, -self.beta_d, -self.beta_x, -self.beta_y, -self.gamma_d]
return ModelDIVA
8 changes: 2 additions & 6 deletions domainlab/models/model_hduva.py
@@ -121,12 +121,8 @@ def cal_reg_loss(self, tensor_x, tensor_y, tensor_d=None, others=None):
# reconstruction
z_concat = self.decoder.concat_ytdx(zy_q, topic_q, zd_q, zx_q)
loss_recon_x, _, _ = self.decoder(z_concat, tensor_x)
batch_loss = loss_recon_x \
- self.beta_x * zx_p_minus_q \
- self.beta_y * zy_p_minus_zy_q \
- self.beta_d * zd_p_minus_q \
- self.beta_t * topic_p_minus_q
return batch_loss
return [loss_recon_x, zx_p_minus_q, zy_p_minus_zy_q, zd_p_minus_q, topic_p_minus_q], \
[1.0, -self.beta_x, -self.beta_y, -self.beta_d, -self.beta_t]

def extract_semantic_features(self, tensor_x):
"""
6 changes: 3 additions & 3 deletions domainlab/models/model_jigen.py
@@ -22,15 +22,15 @@ def mk_jigen(parent_class=AModelClassif):
For more details, see:
Carlucci, Fabio M., et al. "Domain generalization by solving jigsaw puzzles."
Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2019.
Args:
parent_class (AModel, optional): Class object determining the task
type. Defaults to AModelClassif.
Returns:
ModelJiGen: model inheriting from parent class
Input Parameters:
list_str_y: list of labels,
list_str_d: list of domains,
net_encoder: neural network (input: training data, standard and shuffled),
@@ -83,5 +83,5 @@ def cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others=None):
batch_target_scalar = vec_perm_ind
loss_perm = F.cross_entropy(
logits_which_permutation, batch_target_scalar, reduction="none")
return self.alpha*loss_perm
return [loss_perm], [self.alpha]
return ModelJiGen
6 changes: 5 additions & 1 deletion domainlab/models/wrapper_matchdg.py
@@ -30,7 +30,11 @@ def cal_loss(self, tensor_x, tensor_y, tensor_d=None, others=None):
return self.net.cal_loss(tensor_x, tensor_y, tensor_d)

def cal_reg_loss(self, tensor_x, tensor_y, tensor_d=None, others=None):
return self.net.cal_loss(tensor_x, tensor_y, tensor_d) # @FIXME: this is wrong
"""
the abstract method must be implemented to satisfy the interface, but it is never
called because the cal_loss method is overridden
"""
raise NotImplementedError

def forward(self, tensor_x):
"""
