Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

estimate scale ratio at begin #863

Merged
merged 20 commits into from
Jul 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ jobs:
- name: test if api works
run: poetry run python examples/api/jigen_dann_transformer.py
- name: Generate coverage report
run: rm -rf zoutput && poetry run pytest --cov=domainlab tests/ --cov-report=xml
run: rm -rf zoutput && poetry run pytest --maxfail=1 -vvv --tb=short --cov=domainlab tests/ --cov-report=xml
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v1
with:
Expand Down
55 changes: 45 additions & 10 deletions domainlab/algos/trainers/a_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""
import abc

import torch
from torch import optim

from domainlab.compos.pcr.p_chain_handler import AbstractChainNodeHandler
Expand Down Expand Up @@ -88,6 +89,8 @@ def __init__(self, successor_node=None, extend=None):
self.ma_weight_previous_model_params = None
self._dict_previous_para_persist = {}
self._ma_iter = 0
#
self.list_reg_over_task_ratio = None

@property
def model(self):
Expand Down Expand Up @@ -184,11 +187,42 @@ def after_batch(self, epoch, ind_batch):
"""
return

@abc.abstractmethod
def before_tr(self):
"""
hook executed before training starts

marked abstract, so concrete trainers are expected to override it; the
body here only calls cal_reg_loss_over_task_loss_ratio, and therefore
runs only when an override invokes super().before_tr()
"""
self.cal_reg_loss_over_task_loss_ratio()

def cal_reg_loss_over_task_loss_ratio(self):
    """
    estimate how large each regularization loss is relative to the task
    loss, using the first few training batches

    Visits at most ``self.aconf.nb4reg_over_task_ratio`` batches of
    ``self.loader_tr``. For every visited batch, each regularization loss
    returned by ``self.cal_reg_loss`` and the task loss returned by
    ``self.model.cal_task_loss`` are summed to python floats and
    accumulated. The result, one ratio (accumulated reg loss divided by
    accumulated task loss) per regularization term, is stored in
    ``self.list_reg_over_task_ratio``; nothing is returned.

    ``self.list_reg_over_task_ratio`` is set to an empty list when no batch
    was visited or when the accumulated task loss is exactly zero, since no
    ratio can be computed then. An empty list is falsy, so a truthiness
    check on the attribute treats it as "no ratio available".
    """
    list_accum_reg_loss = []
    loss_task_agg = 0
    for ind_batch, (tensor_x, tensor_y, tensor_d, *others) in enumerate(
        self.loader_tr
    ):
        if ind_batch >= self.aconf.nb4reg_over_task_ratio:
            break
        tensor_x, tensor_y, tensor_d = (
            tensor_x.to(self.device),
            tensor_y.to(self.device),
            tensor_d.to(self.device),
        )
        # the multipliers (second return value) are not needed here
        list_reg_loss_tensor, _ = self.cal_reg_loss(
            tensor_x, tensor_y, tensor_d, others
        )
        # reduce each regularization loss tensor to a plain float
        list_reg_loss = [
            torch.sum(tensor).detach().item() for tensor in list_reg_loss_tensor
        ]
        if ind_batch == 0:
            list_accum_reg_loss = list_reg_loss
        else:
            list_accum_reg_loss = [
                reg_loss_accum + reg_loss
                for reg_loss_accum, reg_loss in zip(
                    list_accum_reg_loss, list_reg_loss
                )
            ]
        tensor_loss_task = self.model.cal_task_loss(tensor_x, tensor_y)
        loss_task_agg += torch.sum(tensor_loss_task).detach().item()

    if loss_task_agg == 0:
        # dividing would raise ZeroDivisionError; report "no ratio" instead
        self.list_reg_over_task_ratio = []
        return
    self.list_reg_over_task_ratio = [
        reg_loss / loss_task_agg for reg_loss in list_accum_reg_loss
    ]

def post_tr(self):
"""
Expand Down Expand Up @@ -233,19 +267,20 @@ def cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others=None):
combine losses of current trainer with self._model.cal_reg_loss, which
can be either a trainer or a model
"""
list_reg_model, list_mu_model = self.decoratee.cal_reg_loss(
tensor_x, tensor_y, tensor_d, others
)
assert len(list_reg_model) == len(list_mu_model)
list_reg_loss_model_tensor, list_mu_model = \
self.decoratee.cal_reg_loss(tensor_x, tensor_y, tensor_d, others)
assert len(list_reg_loss_model_tensor) == len(list_mu_model)

list_reg_trainer, list_mu_trainer = self._cal_reg_loss(
list_reg_loss_trainer_tensor, list_mu_trainer = self._cal_reg_loss(
tensor_x, tensor_y, tensor_d, others
)
assert len(list_reg_trainer) == len(list_mu_trainer)

list_loss = list_reg_model + list_reg_trainer
assert len(list_reg_loss_trainer_tensor) == len(list_mu_trainer)
# extend the length of list: extend number of regularization loss
# tensor: the element of list is tensor
list_loss_tensor = list_reg_loss_model_tensor + \
list_reg_loss_trainer_tensor
list_mu = list_mu_model + list_mu_trainer
return list_loss, list_mu
return list_loss_tensor, list_mu

def _cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others=None):
"""
Expand Down
15 changes: 11 additions & 4 deletions domainlab/algos/trainers/train_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,8 @@
import math
from operator import add

import torch

from domainlab import g_tensor_batch_agg
from domainlab.algos.trainers.a_trainer import AbstractTrainer, mk_opt
from domainlab.algos.trainers.a_trainer import AbstractTrainer


def list_divide(list_val, scalar):
Expand All @@ -24,6 +22,7 @@ def before_tr(self):
check the performance of randomly initialized weight
"""
self.model.evaluate(self.loader_te, self.device)
super().before_tr()

def before_epoch(self):
"""
Expand Down Expand Up @@ -95,8 +94,16 @@ def cal_loss(self, tensor_x, tensor_y, tensor_d, others):
list_reg_tr_batch, list_mu_tr = self.cal_reg_loss(
tensor_x, tensor_y, tensor_d, others
)

list_mu_tr_normalized = list_mu_tr
if self.list_reg_over_task_ratio:
assert len(list_mu_tr) == len(self.list_reg_over_task_ratio)
list_mu_tr_normalized = \
[mu / reg_over_task_ratio if reg_over_task_ratio != 0
else mu for (mu, reg_over_task_ratio)
in zip(list_mu_tr, self.list_reg_over_task_ratio)]
tensor_batch_reg_loss_penalized = self.model.list_inner_product(
list_reg_tr_batch, list_mu_tr
list_reg_tr_batch, list_mu_tr_normalized
)
assert len(tensor_batch_reg_loss_penalized.shape) == 1
loss_erm_agg = g_tensor_batch_agg(loss_task)
Expand Down
1 change: 1 addition & 0 deletions domainlab/algos/trainers/train_hyper_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def before_tr(self):
total_steps=self.aconf.warmup,
flag_update_epoch=True,
)
super().before_tr()

def tr_epoch(self, epoch):
"""
Expand Down
1 change: 1 addition & 0 deletions domainlab/algos/trainers/train_mldg.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def before_tr(self):
flag_accept=False,
)
self.prepare_ziped_loader()
super().before_tr()

def prepare_ziped_loader(self):
"""
Expand Down
9 changes: 9 additions & 0 deletions domainlab/arg_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,15 @@ def mk_parser_main():
Set to 0 to turn warmup off.",
)

parser.add_argument(
"-nb4ratio",
"--nb4reg_over_task_ratio",
type=int,
default=1,
help="number of batches for estimating reg loss over task loss ratio \
default 1",
)

parser.add_argument("--debug", action="store_true", default=False)
parser.add_argument("--dmem", action="store_true", default=False)
parser.add_argument(
Expand Down
Loading
Loading