diff --git a/domainlab/algos/trainers/a_trainer.py b/domainlab/algos/trainers/a_trainer.py
index 6386e4dcb..3b2be6f3a 100644
--- a/domainlab/algos/trainers/a_trainer.py
+++ b/domainlab/algos/trainers/a_trainer.py
@@ -5,6 +5,7 @@
 from torch import optim
 from domainlab.compos.pcr.p_chain_handler import AbstractChainNodeHandler
 
+
 def mk_opt(model, aconf):
     """
     create optimizer
@@ -41,6 +42,9 @@ def __init__(self, successor_node=None):
         self.flag_update_hyper_per_epoch = None
         self.flag_update_hyper_per_batch = None
         self.epo_loss_tr = None
+        self.epo_reg_loss_tr = None
+        self.epo_task_loss_tr = None
+        self.counter_batch = None
         self.hyper_scheduler = None
         self.optimizer = None
         self.exp = None
diff --git a/domainlab/algos/trainers/train_basic.py b/domainlab/algos/trainers/train_basic.py
index efb3e2b1e..f63ca79c6 100644
--- a/domainlab/algos/trainers/train_basic.py
+++ b/domainlab/algos/trainers/train_basic.py
@@ -2,11 +2,16 @@
 basic trainer
 """
 import math
+from operator import add
 
 from domainlab.algos.trainers.a_trainer import AbstractTrainer
 from domainlab.algos.trainers.a_trainer import mk_opt
 
 
+def list_divide(list_val, scalar):
+    return [ele/scalar for ele in list_val]
+
+
 class TrainerBasic(AbstractTrainer):
     """
     basic trainer
@@ -19,24 +24,45 @@ def before_tr(self):
 
     def tr_epoch(self, epoch):
         self.model.train()
+        self.counter_batch = 0.0
         self.epo_loss_tr = 0
+        self.epo_reg_loss_tr = [0.0 for _ in range(10)]
+        self.epo_task_loss_tr = 0
         for ind_batch, (tensor_x, vec_y, vec_d, *others) in enumerate(self.loader_tr):
-            self.before_batch(epoch, ind_batch)
-            tensor_x, vec_y, vec_d = \
-                tensor_x.to(self.device), vec_y.to(self.device), vec_d.to(self.device)
-            self.optimizer.zero_grad()
-            loss = self.model.cal_loss(tensor_x, vec_y, vec_d, others)
-            loss = loss.sum()
-            loss.backward()
-            self.optimizer.step()
-            self.epo_loss_tr += loss.detach().item()
-            self.after_batch(epoch, ind_batch)
+            self.tr_batch(tensor_x, vec_y, vec_d, others, ind_batch, epoch)
+        self.epo_loss_tr /= self.counter_batch
+        self.epo_task_loss_tr /= self.counter_batch
+        self.epo_reg_loss_tr = list_divide(self.epo_reg_loss_tr, self.counter_batch)
         assert self.epo_loss_tr is not None
         assert not math.isnan(self.epo_loss_tr)
         flag_stop = self.observer.update(epoch)  # notify observer
         assert flag_stop is not None
         return flag_stop
 
+    def handle_r_loss(self, list_b_reg_loss):
+        list_b_reg_loss_sumed = [ele.sum().detach().item() for ele in list_b_reg_loss]
+        self.epo_reg_loss_tr = list(map(add, self.epo_reg_loss_tr, list_b_reg_loss_sumed))
+        return list_b_reg_loss_sumed
+
+    def tr_batch(self, tensor_x, vec_y, vec_d, others, ind_batch, epoch):
+        """
+        different from self.train_batch(...), which is used for mldg, the current function
+        is used inside tr_epoch
+        """
+        self.before_batch(epoch, ind_batch)
+        tensor_x, vec_y, vec_d = \
+            tensor_x.to(self.device), vec_y.to(self.device), vec_d.to(self.device)
+        self.optimizer.zero_grad()
+        loss, list_loss_reg, loss_task = self.model.cal_loss(tensor_x, vec_y, vec_d, others)
+        self.handle_r_loss(list_loss_reg)
+        loss = loss.sum()
+        loss.backward()
+        self.optimizer.step()
+        self.epo_loss_tr += loss.detach().item()
+        self.epo_task_loss_tr += loss_task.sum().detach().item()
+        self.after_batch(epoch, ind_batch)
+        self.counter_batch += 1
+
     def train_batch(self, tensor_x, vec_y, vec_d, others):
         """
         use a temporary optimizer to update only the model upon a batch of data
@@ -46,7 +72,7 @@ def train_batch(self, tensor_x, vec_y, vec_d, others):
         tensor_x, vec_y, vec_d = \
             tensor_x.to(self.device), vec_y.to(self.device), vec_d.to(self.device)
         optimizer.zero_grad()
-        loss = self.model.cal_loss(tensor_x, vec_y, vec_d, others)
+        loss, *_ = self.model.cal_loss(tensor_x, vec_y, vec_d, others)
         loss = loss.sum()
         loss.backward()
         optimizer.step()
diff --git a/domainlab/models/a_model.py b/domainlab/models/a_model.py
index b009bb314..60e97411b 100644
--- a/domainlab/models/a_model.py
+++ b/domainlab/models/a_model.py
@@ -31,12 +31,16 @@ def cal_loss(self, tensor_x, tensor_y, tensor_d=None, others=None):
         """
         list_loss, list_multiplier = self.cal_reg_loss(tensor_x, tensor_y, tensor_d, others)
         loss_reg = self.inner_product(list_loss, list_multiplier)
-        loss_task = self.multiplier4task_loss * self.cal_task_loss(tensor_x, tensor_y)
-        return loss_task + loss_reg
+        loss_task_alone = self.cal_task_loss(tensor_x, tensor_y)
+        loss_task = self.multiplier4task_loss * loss_task_alone
+        return loss_task + loss_reg, list_loss, loss_task_alone
 
     def inner_product(self, list_loss_scalar, list_multiplier):
         """
         compute inner product between list of scalar loss and multiplier
+        - the first dimension of the tensor v_reg_loss is mini-batch
+          the second dimension is the number of regularizers
+        - the vector mmu has dimension the number of regularizers
         """
         list_tuple = zip(list_loss_scalar, list_multiplier)
         rst = [mtuple[0]*mtuple[1] for mtuple in list_tuple]
diff --git a/domainlab/models/a_model_classif.py b/domainlab/models/a_model_classif.py
index b27a12fa8..a21d79bdb 100644
--- a/domainlab/models/a_model_classif.py
+++ b/domainlab/models/a_model_classif.py
@@ -3,11 +3,11 @@
 """
 import abc
-import numpy as np
 import math
 
+import numpy as np
 import pandas as pd
 import torch
-from torch import nn as nn
+from torch import nn
 from torch.nn import functional as F
 
 from domainlab.models.a_model import AModel
@@ -197,4 +197,4 @@ def cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others=None):
         """
         for ERM to adapt to the interface of other regularized learners
         """
-        return [0], [0]
+        return torch.Tensor([0]), torch.Tensor([0])
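
Note (not part of the diff): a minimal sketch of the new cal_loss contract, using a hypothetical ToyModel that stands in for any AModel subclass. After this change cal_loss returns a 3-tuple of the combined loss, the list of per-regularizer losses, and the unweighted task loss; callers that only need the combined loss unpack it as loss, *_ = ..., which is what train_batch now does.

import torch

class ToyModel:
    """hypothetical stand-in for an AModel subclass after this change"""
    def cal_loss(self, tensor_x, tensor_y, tensor_d=None, others=None):
        # per-sample task loss and a single per-sample regularization loss (illustrative values)
        loss_task_alone = torch.ones(len(tensor_x))
        list_loss_reg = [0.1 * torch.ones(len(tensor_x))]
        loss_total = loss_task_alone + sum(list_loss_reg)
        # new return signature: (combined loss, per-regularizer losses, unweighted task loss)
        return loss_total, list_loss_reg, loss_task_alone

model = ToyModel()
x, y = torch.zeros(4, 3), torch.zeros(4)
# trainers that log per-regularizer losses unpack all three values (as tr_batch does) ...
loss, list_loss_reg, loss_task = model.cal_loss(x, y)
# ... while callers that only need the combined objective discard the rest (as train_batch does)
loss_only, *_ = model.cal_loss(x, y)
print(loss.sum().item(), loss_task.sum().item())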
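
A second illustrative sketch (assumed shapes, not taken from the diff) of the dimension note added to inner_product: each entry of list_loss_scalar is a per-sample regularization loss of length batch size, list_multiplier holds one scalar per regularizer, and weighting then summing over regularizers yields a per-sample regularization loss.

import torch

batch_size, num_reg = 4, 2
# one per-sample loss tensor per regularizer; stacking them column-wise would give the
# (batch_size, num_reg) matrix the docstring calls v_reg_loss
list_loss_scalar = [torch.rand(batch_size) for _ in range(num_reg)]
# one multiplier per regularizer, i.e. the vector the docstring calls mmu
list_multiplier = [1.0, 0.5]

# same elementwise weighting as AModel.inner_product; the final sum over regularizers
# is assumed from the fact that cal_loss adds loss_reg to loss_task
rst = [loss * mult for loss, mult in zip(list_loss_scalar, list_multiplier)]
loss_reg = sum(rst)  # per-sample regularization loss, shape (batch_size,)
print(loss_reg.shape)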