diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 9d93b8530..86ca0b445 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -32,7 +32,7 @@ jobs: - name: test if api works run: poetry run python examples/api/jigen_dann_transformer.py - name: Generate coverage report - run: rm -rf zoutput && poetry run pytest --cov=domainlab tests/ --cov-report=xml + run: rm -rf zoutput && poetry run pytest --maxfail=1 -vvv --tb=short --cov=domainlab tests/ --cov-report=xml - name: Upload coverage to Codecov uses: codecov/codecov-action@v1 with: diff --git a/domainlab/algos/trainers/a_trainer.py b/domainlab/algos/trainers/a_trainer.py index 177a2b28e..67120a1b6 100644 --- a/domainlab/algos/trainers/a_trainer.py +++ b/domainlab/algos/trainers/a_trainer.py @@ -3,6 +3,7 @@ """ import abc +import torch from torch import optim from domainlab.compos.pcr.p_chain_handler import AbstractChainNodeHandler @@ -88,6 +89,8 @@ def __init__(self, successor_node=None, extend=None): self.ma_weight_previous_model_params = None self._dict_previous_para_persist = {} self._ma_iter = 0 + # + self.list_reg_over_task_ratio = None @property def model(self): @@ -184,11 +187,42 @@ def after_batch(self, epoch, ind_batch): """ return - @abc.abstractmethod def before_tr(self): """ before training, probe model performance """ + self.cal_reg_loss_over_task_loss_ratio() + + def cal_reg_loss_over_task_loss_ratio(self): + list_accum_reg_loss = [] + loss_task_agg = 0 + for ind_batch, (tensor_x, tensor_y, tensor_d, *others) in enumerate( + self.loader_tr + ): + if ind_batch >= self.aconf.nb4reg_over_task_ratio: + break + tensor_x, tensor_y, tensor_d = ( + tensor_x.to(self.device), + tensor_y.to(self.device), + tensor_d.to(self.device), + ) + list_reg_loss_tensor, _ = \ + self.cal_reg_loss(tensor_x, tensor_y, tensor_d, others) + list_reg_loss_tensor = [torch.sum(tensor).detach().item() + for tensor in list_reg_loss_tensor] + if ind_batch == 0: + list_accum_reg_loss = list_reg_loss_tensor + else: + list_accum_reg_loss = [reg_loss_accum_tensor + reg_loss_tensor + for reg_loss_accum_tensor, + reg_loss_tensor in + zip(list_accum_reg_loss, + list_reg_loss_tensor)] + tensor_loss_task = self.model.cal_task_loss(tensor_x, tensor_y) + tensor_loss_task = torch.sum(tensor_loss_task).detach().item() + loss_task_agg += tensor_loss_task + self.list_reg_over_task_ratio = [reg_loss / loss_task_agg + for reg_loss in list_accum_reg_loss] def post_tr(self): """ @@ -233,19 +267,20 @@ def cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others=None): combine losses of current trainer with self._model.cal_reg_loss, which can be either a trainer or a model """ - list_reg_model, list_mu_model = self.decoratee.cal_reg_loss( - tensor_x, tensor_y, tensor_d, others - ) - assert len(list_reg_model) == len(list_mu_model) + list_reg_loss_model_tensor, list_mu_model = \ + self.decoratee.cal_reg_loss(tensor_x, tensor_y, tensor_d, others) + assert len(list_reg_loss_model_tensor) == len(list_mu_model) - list_reg_trainer, list_mu_trainer = self._cal_reg_loss( + list_reg_loss_trainer_tensor, list_mu_trainer = self._cal_reg_loss( tensor_x, tensor_y, tensor_d, others ) - assert len(list_reg_trainer) == len(list_mu_trainer) - - list_loss = list_reg_model + list_reg_trainer + assert len(list_reg_loss_trainer_tensor) == len(list_mu_trainer) + # extend the length of list: extend number of regularization loss + # tensor: the element of list is tensor + list_loss_tensor = list_reg_loss_model_tensor + \ + list_reg_loss_trainer_tensor list_mu = list_mu_model + list_mu_trainer - return list_loss, list_mu + return list_loss_tensor, list_mu def _cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others=None): """ diff --git a/domainlab/algos/trainers/train_basic.py b/domainlab/algos/trainers/train_basic.py index 7e6d7cac7..179848467 100644 --- a/domainlab/algos/trainers/train_basic.py +++ b/domainlab/algos/trainers/train_basic.py @@ -4,10 +4,8 @@ import math from operator import add -import torch - from domainlab import g_tensor_batch_agg -from domainlab.algos.trainers.a_trainer import AbstractTrainer, mk_opt +from domainlab.algos.trainers.a_trainer import AbstractTrainer def list_divide(list_val, scalar): @@ -24,6 +22,7 @@ def before_tr(self): check the performance of randomly initialized weight """ self.model.evaluate(self.loader_te, self.device) + super().before_tr() def before_epoch(self): """ @@ -95,8 +94,16 @@ def cal_loss(self, tensor_x, tensor_y, tensor_d, others): list_reg_tr_batch, list_mu_tr = self.cal_reg_loss( tensor_x, tensor_y, tensor_d, others ) + + list_mu_tr_normalized = list_mu_tr + if self.list_reg_over_task_ratio: + assert len(list_mu_tr) == len(self.list_reg_over_task_ratio) + list_mu_tr_normalized = \ + [mu / reg_over_task_ratio if reg_over_task_ratio != 0 + else mu for (mu, reg_over_task_ratio) + in zip(list_mu_tr, self.list_reg_over_task_ratio)] tensor_batch_reg_loss_penalized = self.model.list_inner_product( - list_reg_tr_batch, list_mu_tr + list_reg_tr_batch, list_mu_tr_normalized ) assert len(tensor_batch_reg_loss_penalized.shape) == 1 loss_erm_agg = g_tensor_batch_agg(loss_task) diff --git a/domainlab/algos/trainers/train_hyper_scheduler.py b/domainlab/algos/trainers/train_hyper_scheduler.py index 2e60bf5e8..0a89e7691 100644 --- a/domainlab/algos/trainers/train_hyper_scheduler.py +++ b/domainlab/algos/trainers/train_hyper_scheduler.py @@ -54,6 +54,7 @@ def before_tr(self): total_steps=self.aconf.warmup, flag_update_epoch=True, ) + super().before_tr() def tr_epoch(self, epoch): """ diff --git a/domainlab/algos/trainers/train_mldg.py b/domainlab/algos/trainers/train_mldg.py index e91310adf..2c7376636 100644 --- a/domainlab/algos/trainers/train_mldg.py +++ b/domainlab/algos/trainers/train_mldg.py @@ -35,6 +35,7 @@ def before_tr(self): flag_accept=False, ) self.prepare_ziped_loader() + super().before_tr() def prepare_ziped_loader(self): """ diff --git a/domainlab/arg_parser.py b/domainlab/arg_parser.py index 1b8593e8f..cf9028cda 100644 --- a/domainlab/arg_parser.py +++ b/domainlab/arg_parser.py @@ -113,6 +113,15 @@ def mk_parser_main(): Set to 0 to turn warmup off.", ) + parser.add_argument( + "-nb4ratio", + "--nb4reg_over_task_ratio", + type=int, + default=1, + help="number of batches for estimating reg loss over task loss ratio \ + default 1", + ) + parser.add_argument("--debug", action="store_true", default=False) parser.add_argument("--dmem", action="store_true", default=False) parser.add_argument( diff --git a/tests/test_msel_oracle.py b/tests/test_msel_oracle.py index 8a28567b0..03012bb18 100644 --- a/tests/test_msel_oracle.py +++ b/tests/test_msel_oracle.py @@ -1,108 +1,19 @@ """ executing mk_exp multiple times will cause deep copy to be called multiple times, pytest will show process got killed. """ -from torch import nn -from torchvision import models as torchvisionmodels -from torchvision.models import ResNet50_Weights - -from domainlab.algos.msels.c_msel_oracle import MSelOracleVisitor -from domainlab.algos.msels.c_msel_val import MSelValPerf -from domainlab.algos.observers.b_obvisitor import ObVisitor -from domainlab.arg_parser import mk_parser_main -from domainlab.dsets.dset_mnist_color_solo_default import DsetMNISTColorSoloDefault -from domainlab.exp.exp_main import Exp -from domainlab.models.model_erm import mk_erm -from domainlab.tasks.task_dset import mk_task_dset -from domainlab.tasks.utils_task import ImSize -from domainlab.utils.utils_cuda import get_device - - -def mk_exp( - task, - model, - trainer: str, - test_domain: str, - batchsize: int, - alone=True, - force_best_val=False, - msel_loss_tr=False, -): - """ - Creates a custom experiment. The user can specify the input parameters. - - Input Parameters: - - task: create a task to a custom dataset by importing "mk_task_dset" - function from - "domainlab.tasks.task_dset". For more explanation on the input params - refer to the - documentation found in "domainlab.tasks.task_dset.py". - - model: create a model [NameOfModel] by importing "mk_[NameOfModel]" - function from - "domainlab.models.model_[NameOfModel]". For a concrete example and - explanation of the input - params refer to the documentation found in - "domainlab.models.model_[NameOfModel].py" - - trainer: string, - - test_domain: string, - - batch size: int - - Returns: experiment - """ - - str_arg = f"--model=apimodel --trainer={trainer} \ - --te_d={test_domain} --bs={batchsize}" - if msel_loss_tr: - str_arg = f"--model=apimodel --trainer={trainer} \ - --te_d={test_domain} --bs={batchsize} --msel=loss_tr" - - parser = mk_parser_main() - conf = parser.parse_args(str_arg.split()) - device = get_device(conf) - if alone: - model_sel = MSelOracleVisitor() - else: - model_sel = MSelOracleVisitor(MSelValPerf(max_es=0)) - if force_best_val: - model_sel.msel._best_val_acc = 1.0 - observer = ObVisitor(model_sel) - exp = Exp(conf, task, model=model, observer=observer) - model_sel.update(epoch=1, clear_counter=True) - return exp - +from tests.utils_task_model import mk_exp, mk_model, mk_task def test_msel_oracle(): """ return trainer, model, observer """ - task = mk_task_dset(isize=ImSize(3, 28, 28), dim_y=10, taskna="custom_task") - task.add_domain( - name="domain1", - dset_tr=DsetMNISTColorSoloDefault(0), - dset_val=DsetMNISTColorSoloDefault(1), - ) - task.add_domain( - name="domain2", - dset_tr=DsetMNISTColorSoloDefault(2), - dset_val=DsetMNISTColorSoloDefault(3), - ) - task.add_domain( - name="domain3", - dset_tr=DsetMNISTColorSoloDefault(4), - dset_val=DsetMNISTColorSoloDefault(5), - ) - # specify backbone to use - backbone = torchvisionmodels.resnet.resnet50(weights=ResNet50_Weights.IMAGENET1K_V2) - num_final_in = backbone.fc.in_features - backbone.fc = nn.Linear(num_final_in, task.dim_y) - - # specify model to use - model = mk_erm(list_str_y=task.list_str_y)(backbone) - + task = mk_task() + model = mk_model(task) # make trainer for model - exp = mk_exp(task, model, trainer="mldg", test_domain="domain1", batchsize=32) + exp = mk_exp(task, model, trainer="mldg", test_domain="domain1", batchsize=2) exp.execute(num_epochs=2) - + exp.clean_up() del exp @@ -110,39 +21,15 @@ def test_msel_oracle1(): """ return trainer, model, observer """ - task = mk_task_dset(isize=ImSize(3, 28, 28), dim_y=10, taskna="custom_task") - task.add_domain( - name="domain1", - dset_tr=DsetMNISTColorSoloDefault(0), - dset_val=DsetMNISTColorSoloDefault(1), - ) - task.add_domain( - name="domain2", - dset_tr=DsetMNISTColorSoloDefault(2), - dset_val=DsetMNISTColorSoloDefault(3), - ) - task.add_domain( - name="domain3", - dset_tr=DsetMNISTColorSoloDefault(4), - dset_val=DsetMNISTColorSoloDefault(5), - ) - - # specify backbone to use - backbone = torchvisionmodels.resnet.resnet50(weights=ResNet50_Weights.IMAGENET1K_V2) - num_final_in = backbone.fc.in_features - backbone.fc = nn.Linear(num_final_in, task.dim_y) - - # specify model to use - model = mk_erm(list_str_y=task.list_str_y)(backbone) - - # make trainer for model - + task = mk_task() + model = mk_model(task) exp = mk_exp( - task, model, trainer="mldg", test_domain="domain1", batchsize=32, alone=False + task, model, trainer="mldg", test_domain="domain1", batchsize=2, alone=False ) exp.execute(num_epochs=2) exp.trainer.observer.model_sel.msel.update(epoch=1, clear_counter=True) + exp.clean_up() del exp @@ -150,116 +37,32 @@ def test_msel_oracle2(): """ return trainer, model, observer """ - task = mk_task_dset(isize=ImSize(3, 28, 28), dim_y=10, taskna="custom_task") - task.add_domain( - name="domain1", - dset_tr=DsetMNISTColorSoloDefault(0), - dset_val=DsetMNISTColorSoloDefault(1), - ) - task.add_domain( - name="domain2", - dset_tr=DsetMNISTColorSoloDefault(2), - dset_val=DsetMNISTColorSoloDefault(3), - ) - task.add_domain( - name="domain3", - dset_tr=DsetMNISTColorSoloDefault(4), - dset_val=DsetMNISTColorSoloDefault(5), - ) - - # specify backbone to use - backbone = torchvisionmodels.resnet.resnet50(weights=ResNet50_Weights.IMAGENET1K_V2) - num_final_in = backbone.fc.in_features - backbone.fc = nn.Linear(num_final_in, task.dim_y) - - # specify model to use - model = mk_erm(list_str_y=task.list_str_y)(backbone) + task = mk_task() + model = mk_model(task) # make trainer for model - exp = mk_exp(task, model, trainer="mldg", test_domain="domain1", batchsize=32) + exp = mk_exp(task, model, trainer="mldg", test_domain="domain1", batchsize=2) exp.execute(num_epochs=2) + exp.clean_up() + del exp def test_msel_oracle3(): """ return trainer, model, observer """ - task = mk_task_dset(isize=ImSize(3, 28, 28), dim_y=10, taskna="custom_task") - task.add_domain( - name="domain1", - dset_tr=DsetMNISTColorSoloDefault(0), - dset_val=DsetMNISTColorSoloDefault(1), - ) - task.add_domain( - name="domain2", - dset_tr=DsetMNISTColorSoloDefault(2), - dset_val=DsetMNISTColorSoloDefault(3), - ) - task.add_domain( - name="domain3", - dset_tr=DsetMNISTColorSoloDefault(4), - dset_val=DsetMNISTColorSoloDefault(5), - ) - - # specify backbone to use - backbone = torchvisionmodels.resnet.resnet50(weights=ResNet50_Weights.IMAGENET1K_V2) - num_final_in = backbone.fc.in_features - backbone.fc = nn.Linear(num_final_in, task.dim_y) - - # specify model to use - model = mk_erm(list_str_y=task.list_str_y)(backbone) + task = mk_task() + model = mk_model(task) exp = mk_exp( task, model, trainer="mldg", test_domain="domain1", - batchsize=32, + batchsize=2, alone=False, force_best_val=True, ) exp.execute(num_epochs=2) - del exp - - -def test_msel_oracle4(): - """ - return trainer, model, observer - """ - task = mk_task_dset(isize=ImSize(3, 28, 28), dim_y=10, taskna="custom_task") - task.add_domain( - name="domain1", - dset_tr=DsetMNISTColorSoloDefault(0), - dset_val=DsetMNISTColorSoloDefault(1), - ) - task.add_domain( - name="domain2", - dset_tr=DsetMNISTColorSoloDefault(2), - dset_val=DsetMNISTColorSoloDefault(3), - ) - task.add_domain( - name="domain3", - dset_tr=DsetMNISTColorSoloDefault(4), - dset_val=DsetMNISTColorSoloDefault(5), - ) - - # specify backbone to use - backbone = torchvisionmodels.resnet.resnet50(weights=ResNet50_Weights.IMAGENET1K_V2) - num_final_in = backbone.fc.in_features - backbone.fc = nn.Linear(num_final_in, task.dim_y) - - # specify model to use - model = mk_erm(list_str_y=task.list_str_y)(backbone) - exp = mk_exp( - task, - model, - trainer="mldg", - test_domain="domain1", - batchsize=32, - alone=False, - msel_loss_tr=True, - ) - exp.execute(num_epochs=2) - exp.trainer.observer.model_sel.msel.best_loss = 0 - exp.trainer.observer.model_sel.msel.update(epoch = 1, clear_counter=True) + exp.clean_up() del exp diff --git a/tests/test_msel_oracle2.py b/tests/test_msel_oracle2.py new file mode 100644 index 000000000..651a38f2b --- /dev/null +++ b/tests/test_msel_oracle2.py @@ -0,0 +1,27 @@ +""" +executing mk_exp multiple times will cause deep copy to be called multiple times, pytest will show process got killed. +""" +from tests.utils_task_model import mk_exp, mk_model, mk_task + + +def test_msel_oracle4(): + """ + return trainer, model, observer + """ + task = mk_task() + model = mk_model(task) + # specify backbone to use + exp = mk_exp( + task, + model, + trainer="mldg", + test_domain="domain1", + batchsize=2, + alone=False, + msel_loss_tr=True, + ) + exp.execute(num_epochs=2) + exp.trainer.observer.model_sel.msel.best_loss = 0 + exp.trainer.observer.model_sel.msel.update(epoch=1, clear_counter=True) + exp.clean_up() + del exp diff --git a/tests/utils_task_model.py b/tests/utils_task_model.py new file mode 100644 index 000000000..4fa70796a --- /dev/null +++ b/tests/utils_task_model.py @@ -0,0 +1,94 @@ +from torch import nn +from torchvision import models as torchvisionmodels +from torchvision.models import ResNet50_Weights + +from domainlab.algos.msels.c_msel_oracle import MSelOracleVisitor +from domainlab.algos.msels.c_msel_val import MSelValPerf +from domainlab.algos.observers.b_obvisitor import ObVisitor +from domainlab.arg_parser import mk_parser_main +from domainlab.dsets.dset_mnist_color_solo_default import DsetMNISTColorSoloDefault +from domainlab.exp.exp_main import Exp +from domainlab.models.model_erm import mk_erm +from domainlab.tasks.task_dset import mk_task_dset +from domainlab.tasks.utils_task import ImSize + + +def mk_model(task): + backbone = torchvisionmodels.resnet.resnet50(weights=ResNet50_Weights.IMAGENET1K_V2) + num_final_in = backbone.fc.in_features + backbone.fc = nn.Linear(num_final_in, task.dim_y) + + # specify model to use + model = mk_erm(list_str_y=task.list_str_y)(backbone) + return model + +def mk_exp( + task, + model, + trainer: str, + test_domain: str, + batchsize: int, + alone=True, + force_best_val=False, + msel_loss_tr=False, +): + """ + Creates a custom experiment. The user can specify the input parameters. + + Input Parameters: + - task: create a task to a custom dataset by importing "mk_task_dset" + function from + "domainlab.tasks.task_dset". For more explanation on the input params + refer to the + documentation found in "domainlab.tasks.task_dset.py". + - model: create a model [NameOfModel] by importing "mk_[NameOfModel]" + function from + "domainlab.models.model_[NameOfModel]". For a concrete example and + explanation of the input + params refer to the documentation found in + "domainlab.models.model_[NameOfModel].py" + - trainer: string, + - test_domain: string, + - batch size: int + + Returns: experiment + """ + + str_arg = f"--model=apimodel --trainer={trainer} \ + --te_d={test_domain} --bs={batchsize}" + if msel_loss_tr: + str_arg = f"--model=apimodel --trainer={trainer} \ + --te_d={test_domain} --bs={batchsize} --msel=loss_tr" + + parser = mk_parser_main() + conf = parser.parse_args(str_arg.split()) + if alone: + model_sel = MSelOracleVisitor() + else: + model_sel = MSelOracleVisitor(MSelValPerf(max_es=0)) + if force_best_val: + model_sel.msel._best_val_acc = 1.0 + observer = ObVisitor(model_sel) + exp = Exp(conf, task, model=model, observer=observer) + model_sel.update(epoch=1, clear_counter=True) + return exp + + +def mk_task(): + task = mk_task_dset(isize=ImSize(3, 28, 28), dim_y=10, taskna="custom_task") + task.add_domain( + name="domain1", + dset_tr=DsetMNISTColorSoloDefault(0), + dset_val=DsetMNISTColorSoloDefault(1), + ) + task.add_domain( + name="domain2", + dset_tr=DsetMNISTColorSoloDefault(2), + dset_val=DsetMNISTColorSoloDefault(3), + ) + task.add_domain( + name="domain3", + dset_tr=DsetMNISTColorSoloDefault(4), + dset_val=DsetMNISTColorSoloDefault(5), + ) + return task