Skip to content

Commit

Permalink
Merge pull request #863 from marrlab/scale_weight
Browse files Browse the repository at this point in the history
  • Loading branch information
smilesun authored Jul 26, 2024
2 parents 1520907 + 34975da commit 4301cd4
Show file tree
Hide file tree
Showing 9 changed files with 207 additions and 230 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ jobs:
- name: test if api works
run: poetry run python examples/api/jigen_dann_transformer.py
- name: Generate coverage report
run: rm -rf zoutput && poetry run pytest --cov=domainlab tests/ --cov-report=xml
run: rm -rf zoutput && poetry run pytest --maxfail=1 -vvv --tb=short --cov=domainlab tests/ --cov-report=xml
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v1
with:
Expand Down
55 changes: 45 additions & 10 deletions domainlab/algos/trainers/a_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""
import abc

import torch
from torch import optim

from domainlab.compos.pcr.p_chain_handler import AbstractChainNodeHandler
Expand Down Expand Up @@ -88,6 +89,8 @@ def __init__(self, successor_node=None, extend=None):
self.ma_weight_previous_model_params = None
self._dict_previous_para_persist = {}
self._ma_iter = 0
#
self.list_reg_over_task_ratio = None

@property
def model(self):
Expand Down Expand Up @@ -184,11 +187,42 @@ def after_batch(self, epoch, ind_batch):
"""
return

@abc.abstractmethod
def before_tr(self):
"""
before training, probe model performance
"""
self.cal_reg_loss_over_task_loss_ratio()

def cal_reg_loss_over_task_loss_ratio(self):
list_accum_reg_loss = []
loss_task_agg = 0
for ind_batch, (tensor_x, tensor_y, tensor_d, *others) in enumerate(
self.loader_tr
):
if ind_batch >= self.aconf.nb4reg_over_task_ratio:
break
tensor_x, tensor_y, tensor_d = (
tensor_x.to(self.device),
tensor_y.to(self.device),
tensor_d.to(self.device),
)
list_reg_loss_tensor, _ = \
self.cal_reg_loss(tensor_x, tensor_y, tensor_d, others)
list_reg_loss_tensor = [torch.sum(tensor).detach().item()
for tensor in list_reg_loss_tensor]
if ind_batch == 0:
list_accum_reg_loss = list_reg_loss_tensor
else:
list_accum_reg_loss = [reg_loss_accum_tensor + reg_loss_tensor
for reg_loss_accum_tensor,
reg_loss_tensor in
zip(list_accum_reg_loss,
list_reg_loss_tensor)]
tensor_loss_task = self.model.cal_task_loss(tensor_x, tensor_y)
tensor_loss_task = torch.sum(tensor_loss_task).detach().item()
loss_task_agg += tensor_loss_task
self.list_reg_over_task_ratio = [reg_loss / loss_task_agg
for reg_loss in list_accum_reg_loss]

def post_tr(self):
"""
Expand Down Expand Up @@ -233,19 +267,20 @@ def cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others=None):
combine losses of current trainer with self._model.cal_reg_loss, which
can be either a trainer or a model
"""
list_reg_model, list_mu_model = self.decoratee.cal_reg_loss(
tensor_x, tensor_y, tensor_d, others
)
assert len(list_reg_model) == len(list_mu_model)
list_reg_loss_model_tensor, list_mu_model = \
self.decoratee.cal_reg_loss(tensor_x, tensor_y, tensor_d, others)
assert len(list_reg_loss_model_tensor) == len(list_mu_model)

list_reg_trainer, list_mu_trainer = self._cal_reg_loss(
list_reg_loss_trainer_tensor, list_mu_trainer = self._cal_reg_loss(
tensor_x, tensor_y, tensor_d, others
)
assert len(list_reg_trainer) == len(list_mu_trainer)

list_loss = list_reg_model + list_reg_trainer
assert len(list_reg_loss_trainer_tensor) == len(list_mu_trainer)
# extend the length of list: extend number of regularization loss
# tensor: the element of list is tensor
list_loss_tensor = list_reg_loss_model_tensor + \
list_reg_loss_trainer_tensor
list_mu = list_mu_model + list_mu_trainer
return list_loss, list_mu
return list_loss_tensor, list_mu

def _cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others=None):
"""
Expand Down
15 changes: 11 additions & 4 deletions domainlab/algos/trainers/train_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,8 @@
import math
from operator import add

import torch

from domainlab import g_tensor_batch_agg
from domainlab.algos.trainers.a_trainer import AbstractTrainer, mk_opt
from domainlab.algos.trainers.a_trainer import AbstractTrainer


def list_divide(list_val, scalar):
Expand All @@ -24,6 +22,7 @@ def before_tr(self):
check the performance of randomly initialized weight
"""
self.model.evaluate(self.loader_te, self.device)
super().before_tr()

def before_epoch(self):
"""
Expand Down Expand Up @@ -95,8 +94,16 @@ def cal_loss(self, tensor_x, tensor_y, tensor_d, others):
list_reg_tr_batch, list_mu_tr = self.cal_reg_loss(
tensor_x, tensor_y, tensor_d, others
)

list_mu_tr_normalized = list_mu_tr
if self.list_reg_over_task_ratio:
assert len(list_mu_tr) == len(self.list_reg_over_task_ratio)
list_mu_tr_normalized = \
[mu / reg_over_task_ratio if reg_over_task_ratio != 0
else mu for (mu, reg_over_task_ratio)
in zip(list_mu_tr, self.list_reg_over_task_ratio)]
tensor_batch_reg_loss_penalized = self.model.list_inner_product(
list_reg_tr_batch, list_mu_tr
list_reg_tr_batch, list_mu_tr_normalized
)
assert len(tensor_batch_reg_loss_penalized.shape) == 1
loss_erm_agg = g_tensor_batch_agg(loss_task)
Expand Down
1 change: 1 addition & 0 deletions domainlab/algos/trainers/train_hyper_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def before_tr(self):
total_steps=self.aconf.warmup,
flag_update_epoch=True,
)
super().before_tr()

def tr_epoch(self, epoch):
"""
Expand Down
1 change: 1 addition & 0 deletions domainlab/algos/trainers/train_mldg.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def before_tr(self):
flag_accept=False,
)
self.prepare_ziped_loader()
super().before_tr()

def prepare_ziped_loader(self):
"""
Expand Down
9 changes: 9 additions & 0 deletions domainlab/arg_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,15 @@ def mk_parser_main():
Set to 0 to turn warmup off.",
)

parser.add_argument(
"-nb4ratio",
"--nb4reg_over_task_ratio",
type=int,
default=1,
help="number of batches for estimating reg loss over task loss ratio \
default 1",
)

parser.add_argument("--debug", action="store_true", default=False)
parser.add_argument("--dmem", action="store_true", default=False)
parser.add_argument(
Expand Down
Loading

0 comments on commit 4301cd4

Please sign in to comment.