cal reg loss in trainer basic
smilesun committed Oct 5, 2023
1 parent bc912b3 commit a387e06
Showing 4 changed files with 50 additions and 16 deletions.
4 changes: 4 additions & 0 deletions domainlab/algos/trainers/a_trainer.py
@@ -5,6 +5,7 @@
from torch import optim
from domainlab.compos.pcr.p_chain_handler import AbstractChainNodeHandler


def mk_opt(model, aconf):
"""
create optimizer
@@ -41,6 +42,9 @@ def __init__(self, successor_node=None):
        self.flag_update_hyper_per_epoch = None
        self.flag_update_hyper_per_batch = None
        self.epo_loss_tr = None
        self.epo_reg_loss_tr = None
        self.epo_task_loss_tr = None
        self.counter_batch = None
        self.hyper_scheduler = None
        self.optimizer = None
        self.exp = None
48 changes: 37 additions & 11 deletions domainlab/algos/trainers/train_basic.py
@@ -2,11 +2,16 @@
basic trainer
"""
import math
from operator import add

from domainlab.algos.trainers.a_trainer import AbstractTrainer
from domainlab.algos.trainers.a_trainer import mk_opt


def list_divide(list_val, scalar):
    """divide every element of a list by the same scalar"""
    return [ele/scalar for ele in list_val]


class TrainerBasic(AbstractTrainer):
"""
basic trainer
@@ -19,24 +24,45 @@ def before_tr(self):

    def tr_epoch(self, epoch):
        self.model.train()
        self.counter_batch = 0.0
        self.epo_loss_tr = 0
        self.epo_reg_loss_tr = [0.0 for _ in range(10)]
        self.epo_task_loss_tr = 0
        for ind_batch, (tensor_x, vec_y, vec_d, *others) in enumerate(self.loader_tr):
            self.before_batch(epoch, ind_batch)
            tensor_x, vec_y, vec_d = \
                tensor_x.to(self.device), vec_y.to(self.device), vec_d.to(self.device)
            self.optimizer.zero_grad()
            loss = self.model.cal_loss(tensor_x, vec_y, vec_d, others)
            loss = loss.sum()
            loss.backward()
            self.optimizer.step()
            self.epo_loss_tr += loss.detach().item()
            self.after_batch(epoch, ind_batch)
            self.tr_batch(tensor_x, vec_y, vec_d, others, ind_batch, epoch)
        self.epo_loss_tr /= self.counter_batch
        self.epo_task_loss_tr /= self.counter_batch
        self.epo_reg_loss_tr = list_divide(self.epo_reg_loss_tr, self.counter_batch)
        assert self.epo_loss_tr is not None
        assert not math.isnan(self.epo_loss_tr)
        flag_stop = self.observer.update(epoch)  # notify observer
        assert flag_stop is not None
        return flag_stop

    def handle_r_loss(self, list_b_reg_loss):
        """accumulate the batch's per-regularizer losses into the epoch totals"""
        list_b_reg_loss_summed = [ele.sum().detach().item() for ele in list_b_reg_loss]
        self.epo_reg_loss_tr = list(map(add, self.epo_reg_loss_tr, list_b_reg_loss_summed))
        return list_b_reg_loss_summed

    def tr_batch(self, tensor_x, vec_y, vec_d, others, ind_batch, epoch):
        """
        unlike self.train_batch(...), which is used for MLDG, this function
        is used inside tr_epoch
        """
        self.before_batch(epoch, ind_batch)
        tensor_x, vec_y, vec_d = \
            tensor_x.to(self.device), vec_y.to(self.device), vec_d.to(self.device)
        self.optimizer.zero_grad()
        loss, list_loss_reg, loss_task = self.model.cal_loss(tensor_x, vec_y, vec_d, others)
        self.handle_r_loss(list_loss_reg)
        loss = loss.sum()
        loss.backward()
        self.optimizer.step()
        self.epo_loss_tr += loss.detach().item()
        self.epo_task_loss_tr += loss_task.sum().detach().item()
        self.after_batch(epoch, ind_batch)
        self.counter_batch += 1

    def train_batch(self, tensor_x, vec_y, vec_d, others):
        """
        use a temporary optimizer to update only the model upon a batch of data
@@ -46,7 +72,7 @@ def train_batch(self, tensor_x, vec_y, vec_d, others):
        tensor_x, vec_y, vec_d = \
            tensor_x.to(self.device), vec_y.to(self.device), vec_d.to(self.device)
        optimizer.zero_grad()
        loss = self.model.cal_loss(tensor_x, vec_y, vec_d, others)
        loss, *_ = self.model.cal_loss(tensor_x, vec_y, vec_d, others)
        loss = loss.sum()
        loss.backward()
        optimizer.step()
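Taken together, the train_basic.py changes make the trainer track three epoch-level statistics: the total loss, the task loss, and one entry per regularizer, each averaged over the number of batches. The following is a minimal, self-contained sketch of that bookkeeping; ToyModel and its losses are illustrative stand-ins, not part of the DomainLab API.

# Illustrative sketch only: ToyModel mimics a model whose cal_loss returns
# (total loss, list of per-regularizer losses, task loss), one value per sample.
from operator import add

import torch


def list_divide(list_val, scalar):
    """divide every element of a list by the same scalar"""
    return [ele / scalar for ele in list_val]


class ToyModel:
    def cal_loss(self, tensor_x, vec_y):
        loss_task = (tensor_x.sum(dim=1) - vec_y) ** 2           # task loss per sample
        list_reg = [0.1 * tensor_x.abs().sum(dim=1),             # toy regularizer 1
                    0.01 * tensor_x.pow(2).sum(dim=1)]           # toy regularizer 2
        return loss_task + sum(list_reg), list_reg, loss_task


model = ToyModel()
epo_loss_tr, epo_task_loss_tr, counter_batch = 0.0, 0.0, 0.0
epo_reg_loss_tr = [0.0 for _ in range(10)]                       # same fixed-size buffer as tr_epoch
for _ in range(3):                                               # three toy "batches"
    tensor_x, vec_y = torch.randn(4, 5), torch.ones(4)
    loss, list_loss_reg, loss_task = model.cal_loss(tensor_x, vec_y)
    # handle_r_loss: sum each regularizer over the batch and accumulate
    reg_summed = [ele.sum().detach().item() for ele in list_loss_reg]
    epo_reg_loss_tr = list(map(add, epo_reg_loss_tr, reg_summed))  # map stops at the shorter list
    epo_loss_tr += loss.sum().detach().item()
    epo_task_loss_tr += loss_task.sum().detach().item()
    counter_batch += 1
# end of epoch: report per-batch averages
print(epo_loss_tr / counter_batch,
      epo_task_loss_tr / counter_batch,
      list_divide(epo_reg_loss_tr, counter_batch))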
8 changes: 6 additions & 2 deletions domainlab/models/a_model.py
@@ -31,12 +31,16 @@ def cal_loss(self, tensor_x, tensor_y, tensor_d=None, others=None):
"""
list_loss, list_multiplier = self.cal_reg_loss(tensor_x, tensor_y, tensor_d, others)
loss_reg = self.inner_product(list_loss, list_multiplier)
loss_task = self.multiplier4task_loss * self.cal_task_loss(tensor_x, tensor_y)
return loss_task + loss_reg
loss_task_alone = self.cal_task_loss(tensor_x, tensor_y)
loss_task = self.multiplier4task_loss * loss_task_alone
return loss_task + loss_reg, list_loss, loss_task_alone

    def inner_product(self, list_loss_scalar, list_multiplier):
        """
        compute the inner product between the list of scalar losses and the multipliers
        - the first dimension of the tensor v_reg_loss is the mini-batch size,
          the second dimension is the number of regularizers
        - the vector mmu has one entry per regularizer
        """
        list_tuple = zip(list_loss_scalar, list_multiplier)
        rst = [mtuple[0]*mtuple[1] for mtuple in list_tuple]
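For orientation, here is a standalone illustration of what inner_product computes and of the new three-part return of cal_loss. The final reduction inside inner_product is not visible in this diff, so summing the weighted terms below is an assumption based on how cal_loss adds loss_reg to loss_task.

import torch

# per-regularizer losses, one value per sample in the mini-batch
list_loss_scalar = [torch.tensor([1.0, 2.0]),   # regularizer 1
                    torch.tensor([0.5, 0.5])]   # regularizer 2
list_multiplier = [0.1, 2.0]                    # one weight ("mmu" entry) per regularizer

# weight each regularizer, then combine (reduction assumed to be a sum)
weighted = [loss * mu for loss, mu in zip(list_loss_scalar, list_multiplier)]
loss_reg = sum(weighted)                        # tensor([1.1000, 1.2000]), one value per sample

# cal_loss now returns the total, the raw regularizer list, and the bare task loss,
# so a caller that only needs the total can unpack it as:  loss, *_ = model.cal_loss(...)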
6 changes: 3 additions & 3 deletions domainlab/models/a_model_classif.py
@@ -3,11 +3,11 @@
"""

import abc
import numpy as np
import math
import numpy as np
import pandas as pd
import torch
from torch import nn as nn
from torch import nn
from torch.nn import functional as F

from domainlab.models.a_model import AModel
@@ -197,4 +197,4 @@ def cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others=None):
"""
for ERM to adapt to the interface of other regularized learners
"""
return [0], [0]
return torch.Tensor([0]), torch.Tensor([0])
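A short sketch of why the placeholder changes from plain ints to tensors: the trainer's new handle_r_loss calls .sum().detach().item() on every entry it receives, which plain Python ints do not support. This example only illustrates that constraint and is not part of the commit.

import torch

reg_int = [0]                      # old placeholder returned by ERM's cal_reg_loss
reg_tensor = [torch.Tensor([0])]   # new placeholder

print([ele.sum().detach().item() for ele in reg_tensor])   # [0.0] -- works
# [ele.sum() for ele in reg_int] would raise AttributeError: 'int' object has no attribute 'sum'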
