From 1005760fabf46bbbfc16a9705e72605053a8d594 Mon Sep 17 00:00:00 2001 From: smilesun Date: Wed, 24 Jul 2024 16:32:12 +0200 Subject: [PATCH 01/18] estimate scale ratio at begin --- domainlab/algos/trainers/train_basic.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/domainlab/algos/trainers/train_basic.py b/domainlab/algos/trainers/train_basic.py index 7e6d7cac7..c0a363ec9 100644 --- a/domainlab/algos/trainers/train_basic.py +++ b/domainlab/algos/trainers/train_basic.py @@ -24,6 +24,19 @@ def before_tr(self): check the performance of randomly initialized weight """ self.model.evaluate(self.loader_te, self.device) + list_accum_reg_loss = [] + loss_task_agg = 0 + for ind_batch, (tensor_x, tensor_y, tensor_d, *others) in enumerate( + self.loader_tr + ): + list_reg_loss, _ = \ + self.cal_reg_loss(tensor_x, tensor_y, tensor_d, others) + if ind_batch == 0: + list_accum_reg_loss = list_reg_loss + list_accum_reg_loss = [reg_loss_accum + reg_loss \ + for reg_loss_accum, reg_loss in + zip(list_accum_reg_loss, list_reg_loss)] + loss_task = self.model.cal_task_loss(tensor_x, tensor_y) def before_epoch(self): """ From 209a137186f46cb5e8e8294310ab2ea5f40d963d Mon Sep 17 00:00:00 2001 From: smilesun Date: Thu, 25 Jul 2024 12:17:00 +0200 Subject: [PATCH 02/18] ratio estimation done --- domainlab/algos/trainers/a_trainer.py | 50 +++++++++++++++---- domainlab/algos/trainers/train_basic.py | 18 +------ .../algos/trainers/train_hyper_scheduler.py | 1 + 3 files changed, 43 insertions(+), 26 deletions(-) diff --git a/domainlab/algos/trainers/a_trainer.py b/domainlab/algos/trainers/a_trainer.py index 8af15bf25..cddfc24cf 100644 --- a/domainlab/algos/trainers/a_trainer.py +++ b/domainlab/algos/trainers/a_trainer.py @@ -3,6 +3,7 @@ """ import abc +import torch from torch import optim from domainlab.compos.pcr.p_chain_handler import AbstractChainNodeHandler @@ -88,6 +89,8 @@ def __init__(self, successor_node=None, extend=None): self.ma_weight_previous_model_params = None self._ma_dict_para_persist = {} self._ma_iter = 0 + # + self.list_reg_over_task_ratio = None @property def model(self): @@ -184,11 +187,37 @@ def after_batch(self, epoch, ind_batch): """ return - @abc.abstractmethod def before_tr(self): """ before training, probe model performance """ + list_accum_reg_loss = [] + loss_task_agg = 0 + for ind_batch, (tensor_x, tensor_y, tensor_d, *others) in enumerate( + self.loader_tr + ): + tensor_x, tensor_y, tensor_d = ( + tensor_x.to(self.device), + tensor_y.to(self.device), + tensor_d.to(self.device), + ) + list_reg_loss_tensor, _ = \ + self.cal_reg_loss(tensor_x, tensor_y, tensor_d, others) + list_reg_loss_tensor = [torch.sum(tensor).detach().item() + for tensor in list_reg_loss_tensor] + if ind_batch == 0: + list_accum_reg_loss = list_reg_loss_tensor + else: + list_accum_reg_loss = [reg_loss_accum_tensor + reg_loss_tensor + for reg_loss_accum_tensor, + reg_loss_tensor in + zip(list_accum_reg_loss, + list_reg_loss_tensor)] + tensor_loss_task = self.model.cal_task_loss(tensor_x, tensor_y) + tensor_loss_task = torch.sum(tensor_loss_task).detach().item() + loss_task_agg += tensor_loss_task + self.list_reg_over_task_ratio = [reg_loss / loss_task_agg + for reg_loss in list_accum_reg_loss] def post_tr(self): """ @@ -233,19 +262,20 @@ def cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others=None): combine losses of current trainer with self._model.cal_reg_loss, which can be either a trainer or a model """ - list_reg_model, list_mu_model = self.decoratee.cal_reg_loss( - tensor_x, tensor_y, tensor_d, others - ) - assert len(list_reg_model) == len(list_mu_model) + list_reg_loss_model_tensor, list_mu_model = \ + self.decoratee.cal_reg_loss(tensor_x, tensor_y, tensor_d, others) + assert len(list_reg_loss_model_tensor) == len(list_mu_model) - list_reg_trainer, list_mu_trainer = self._cal_reg_loss( + list_reg_loss_trainer_tensor, list_mu_trainer = self._cal_reg_loss( tensor_x, tensor_y, tensor_d, others ) - assert len(list_reg_trainer) == len(list_mu_trainer) - - list_loss = list_reg_model + list_reg_trainer + assert len(list_reg_loss_trainer_tensor) == len(list_mu_trainer) + # extend the length of list: extend number of regularization loss + # tensor: the element of list is tensor + list_loss_tensor = list_reg_loss_model_tensor + \ + list_reg_loss_trainer_tensor list_mu = list_mu_model + list_mu_trainer - return list_loss, list_mu + return list_loss_tensor, list_mu def _cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others=None): """ diff --git a/domainlab/algos/trainers/train_basic.py b/domainlab/algos/trainers/train_basic.py index c0a363ec9..406879b07 100644 --- a/domainlab/algos/trainers/train_basic.py +++ b/domainlab/algos/trainers/train_basic.py @@ -4,10 +4,8 @@ import math from operator import add -import torch - from domainlab import g_tensor_batch_agg -from domainlab.algos.trainers.a_trainer import AbstractTrainer, mk_opt +from domainlab.algos.trainers.a_trainer import AbstractTrainer def list_divide(list_val, scalar): @@ -24,19 +22,7 @@ def before_tr(self): check the performance of randomly initialized weight """ self.model.evaluate(self.loader_te, self.device) - list_accum_reg_loss = [] - loss_task_agg = 0 - for ind_batch, (tensor_x, tensor_y, tensor_d, *others) in enumerate( - self.loader_tr - ): - list_reg_loss, _ = \ - self.cal_reg_loss(tensor_x, tensor_y, tensor_d, others) - if ind_batch == 0: - list_accum_reg_loss = list_reg_loss - list_accum_reg_loss = [reg_loss_accum + reg_loss \ - for reg_loss_accum, reg_loss in - zip(list_accum_reg_loss, list_reg_loss)] - loss_task = self.model.cal_task_loss(tensor_x, tensor_y) + super().before_tr() def before_epoch(self): """ diff --git a/domainlab/algos/trainers/train_hyper_scheduler.py b/domainlab/algos/trainers/train_hyper_scheduler.py index 2e60bf5e8..0a89e7691 100644 --- a/domainlab/algos/trainers/train_hyper_scheduler.py +++ b/domainlab/algos/trainers/train_hyper_scheduler.py @@ -54,6 +54,7 @@ def before_tr(self): total_steps=self.aconf.warmup, flag_update_epoch=True, ) + super().before_tr() def tr_epoch(self, epoch): """ From ad6e7f44738da99e1e3c0a2ab8daa760ee8ea39a Mon Sep 17 00:00:00 2001 From: smilesun Date: Thu, 25 Jul 2024 12:29:34 +0200 Subject: [PATCH 03/18] reg over task ratio inserted --- domainlab/algos/trainers/train_basic.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/domainlab/algos/trainers/train_basic.py b/domainlab/algos/trainers/train_basic.py index 406879b07..02d0c9f29 100644 --- a/domainlab/algos/trainers/train_basic.py +++ b/domainlab/algos/trainers/train_basic.py @@ -94,8 +94,11 @@ def cal_loss(self, tensor_x, tensor_y, tensor_d, others): list_reg_tr_batch, list_mu_tr = self.cal_reg_loss( tensor_x, tensor_y, tensor_d, others ) + list_mu_tr_normalized = \ + [mu / reg_over_task_ratio for (mu, reg_over_task_ratio) + in zip(list_mu_tr, self.list_reg_over_task_ratio)] tensor_batch_reg_loss_penalized = self.model.list_inner_product( - list_reg_tr_batch, list_mu_tr + list_reg_tr_batch, list_mu_tr_normalized ) assert len(tensor_batch_reg_loss_penalized.shape) == 1 loss_erm_agg = g_tensor_batch_agg(loss_task) From 9cb3e093410b83809d18a38858858494a040b626 Mon Sep 17 00:00:00 2001 From: smilesun Date: Thu, 25 Jul 2024 13:28:02 +0200 Subject: [PATCH 04/18] . --- domainlab/algos/trainers/train_basic.py | 9 ++++++--- domainlab/algos/trainers/train_mldg.py | 1 + 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/domainlab/algos/trainers/train_basic.py b/domainlab/algos/trainers/train_basic.py index 02d0c9f29..363157d86 100644 --- a/domainlab/algos/trainers/train_basic.py +++ b/domainlab/algos/trainers/train_basic.py @@ -94,9 +94,12 @@ def cal_loss(self, tensor_x, tensor_y, tensor_d, others): list_reg_tr_batch, list_mu_tr = self.cal_reg_loss( tensor_x, tensor_y, tensor_d, others ) - list_mu_tr_normalized = \ - [mu / reg_over_task_ratio for (mu, reg_over_task_ratio) - in zip(list_mu_tr, self.list_reg_over_task_ratio)] + + list_mu_tr_normalized = list_mu_tr + if self.list_reg_over_task_ratio: + list_mu_tr_normalized = \ + [mu / reg_over_task_ratio for (mu, reg_over_task_ratio) + in zip(list_mu_tr, self.list_reg_over_task_ratio)] tensor_batch_reg_loss_penalized = self.model.list_inner_product( list_reg_tr_batch, list_mu_tr_normalized ) diff --git a/domainlab/algos/trainers/train_mldg.py b/domainlab/algos/trainers/train_mldg.py index e91310adf..2c7376636 100644 --- a/domainlab/algos/trainers/train_mldg.py +++ b/domainlab/algos/trainers/train_mldg.py @@ -35,6 +35,7 @@ def before_tr(self): flag_accept=False, ) self.prepare_ziped_loader() + super().before_tr() def prepare_ziped_loader(self): """ From b90c79af8755b12e2f78ee4627dd26702d4bb805 Mon Sep 17 00:00:00 2001 From: smilesun Date: Thu, 25 Jul 2024 15:19:18 +0200 Subject: [PATCH 05/18] fix divide by zero --- domainlab/algos/trainers/train_basic.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/domainlab/algos/trainers/train_basic.py b/domainlab/algos/trainers/train_basic.py index 363157d86..55c3802c9 100644 --- a/domainlab/algos/trainers/train_basic.py +++ b/domainlab/algos/trainers/train_basic.py @@ -98,7 +98,8 @@ def cal_loss(self, tensor_x, tensor_y, tensor_d, others): list_mu_tr_normalized = list_mu_tr if self.list_reg_over_task_ratio: list_mu_tr_normalized = \ - [mu / reg_over_task_ratio for (mu, reg_over_task_ratio) + [mu / reg_over_task_ratio if reg_over_task_ratio != 0 + else mu for (mu, reg_over_task_ratio) in zip(list_mu_tr, self.list_reg_over_task_ratio)] tensor_batch_reg_loss_penalized = self.model.list_inner_product( list_reg_tr_batch, list_mu_tr_normalized From f5efc7617b3452e122b570aab0b7552d85796a18 Mon Sep 17 00:00:00 2001 From: Xudong Sun Date: Thu, 25 Jul 2024 18:10:24 +0200 Subject: [PATCH 06/18] Update train_basic.py --- domainlab/algos/trainers/train_basic.py | 1 + 1 file changed, 1 insertion(+) diff --git a/domainlab/algos/trainers/train_basic.py b/domainlab/algos/trainers/train_basic.py index 55c3802c9..179848467 100644 --- a/domainlab/algos/trainers/train_basic.py +++ b/domainlab/algos/trainers/train_basic.py @@ -97,6 +97,7 @@ def cal_loss(self, tensor_x, tensor_y, tensor_d, others): list_mu_tr_normalized = list_mu_tr if self.list_reg_over_task_ratio: + assert len(list_mu_tr) == len(self.list_reg_over_task_ratio) list_mu_tr_normalized = \ [mu / reg_over_task_ratio if reg_over_task_ratio != 0 else mu for (mu, reg_over_task_ratio) From 55bba059dbfde5ac6144a22829e1c90a17d22b38 Mon Sep 17 00:00:00 2001 From: Xudong Sun Date: Thu, 25 Jul 2024 22:20:45 +0200 Subject: [PATCH 07/18] Update ci.yml, pytest config --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 9d93b8530..86ca0b445 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -32,7 +32,7 @@ jobs: - name: test if api works run: poetry run python examples/api/jigen_dann_transformer.py - name: Generate coverage report - run: rm -rf zoutput && poetry run pytest --cov=domainlab tests/ --cov-report=xml + run: rm -rf zoutput && poetry run pytest --maxfail=1 -vvv --tb=short --cov=domainlab tests/ --cov-report=xml - name: Upload coverage to Codecov uses: codecov/codecov-action@v1 with: From aa53b31ffd419f6ea11a2f4a0ca054feb1a8a280 Mon Sep 17 00:00:00 2001 From: smilesun Date: Fri, 26 Jul 2024 00:20:19 +0200 Subject: [PATCH 08/18] torch no_grad for ratio calculation --- domainlab/algos/trainers/a_trainer.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/domainlab/algos/trainers/a_trainer.py b/domainlab/algos/trainers/a_trainer.py index d02cf1729..24554fee2 100644 --- a/domainlab/algos/trainers/a_trainer.py +++ b/domainlab/algos/trainers/a_trainer.py @@ -191,6 +191,10 @@ def before_tr(self): """ before training, probe model performance """ + with torch.no_grad(): + self.cal_reg_loss_over_task_loss_ratio() + + def cal_reg_loss_over_task_loss_ratio(self): list_accum_reg_loss = [] loss_task_agg = 0 for ind_batch, (tensor_x, tensor_y, tensor_d, *others) in enumerate( From bc4c4edccd97dad5069d52a82c44184bbc532d28 Mon Sep 17 00:00:00 2001 From: Xudong Sun Date: Fri, 26 Jul 2024 00:27:22 +0200 Subject: [PATCH 09/18] remove torch_no_grad: dial, dann need grad --- domainlab/algos/trainers/a_trainer.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/domainlab/algos/trainers/a_trainer.py b/domainlab/algos/trainers/a_trainer.py index 24554fee2..2d9673dba 100644 --- a/domainlab/algos/trainers/a_trainer.py +++ b/domainlab/algos/trainers/a_trainer.py @@ -191,8 +191,7 @@ def before_tr(self): """ before training, probe model performance """ - with torch.no_grad(): - self.cal_reg_loss_over_task_loss_ratio() + self.cal_reg_loss_over_task_loss_ratio() def cal_reg_loss_over_task_loss_ratio(self): list_accum_reg_loss = [] From 15d776d62acc1e0d1c0fc91bc47a030b6be1e63f Mon Sep 17 00:00:00 2001 From: smilesun Date: Fri, 26 Jul 2024 08:48:44 +0200 Subject: [PATCH 10/18] by default no ratio estimation --- domainlab/algos/trainers/a_trainer.py | 2 ++ domainlab/arg_parser.py | 9 +++++++++ 2 files changed, 11 insertions(+) diff --git a/domainlab/algos/trainers/a_trainer.py b/domainlab/algos/trainers/a_trainer.py index 2d9673dba..e62fb1fdd 100644 --- a/domainlab/algos/trainers/a_trainer.py +++ b/domainlab/algos/trainers/a_trainer.py @@ -199,6 +199,8 @@ def cal_reg_loss_over_task_loss_ratio(self): for ind_batch, (tensor_x, tensor_y, tensor_d, *others) in enumerate( self.loader_tr ): + if ind_batch >= self.aconf.nb4ratio: + break tensor_x, tensor_y, tensor_d = ( tensor_x.to(self.device), tensor_y.to(self.device), diff --git a/domainlab/arg_parser.py b/domainlab/arg_parser.py index 1b8593e8f..ead22cd47 100644 --- a/domainlab/arg_parser.py +++ b/domainlab/arg_parser.py @@ -113,6 +113,15 @@ def mk_parser_main(): Set to 0 to turn warmup off.", ) + parser.add_argument( + "-nb4ratio", + "--nb4reg_over_task_ratio", + type=int, + default=0, + help="number of batches for estimating reg loss over task loss ratio \ + default 0", + ) + parser.add_argument("--debug", action="store_true", default=False) parser.add_argument("--dmem", action="store_true", default=False) parser.add_argument( From 589c428671a3fe3d53d26651d669d053bd21049d Mon Sep 17 00:00:00 2001 From: smilesun Date: Fri, 26 Jul 2024 08:53:52 +0200 Subject: [PATCH 11/18] . --- domainlab/algos/trainers/a_trainer.py | 2 +- domainlab/arg_parser.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/domainlab/algos/trainers/a_trainer.py b/domainlab/algos/trainers/a_trainer.py index e62fb1fdd..67120a1b6 100644 --- a/domainlab/algos/trainers/a_trainer.py +++ b/domainlab/algos/trainers/a_trainer.py @@ -199,7 +199,7 @@ def cal_reg_loss_over_task_loss_ratio(self): for ind_batch, (tensor_x, tensor_y, tensor_d, *others) in enumerate( self.loader_tr ): - if ind_batch >= self.aconf.nb4ratio: + if ind_batch >= self.aconf.nb4reg_over_task_ratio: break tensor_x, tensor_y, tensor_d = ( tensor_x.to(self.device), diff --git a/domainlab/arg_parser.py b/domainlab/arg_parser.py index ead22cd47..cf9028cda 100644 --- a/domainlab/arg_parser.py +++ b/domainlab/arg_parser.py @@ -117,9 +117,9 @@ def mk_parser_main(): "-nb4ratio", "--nb4reg_over_task_ratio", type=int, - default=0, + default=1, help="number of batches for estimating reg loss over task loss ratio \ - default 0", + default 1", ) parser.add_argument("--debug", action="store_true", default=False) From c5999461897b89b78511a7e3f07719e3b996407b Mon Sep 17 00:00:00 2001 From: smilesun Date: Fri, 26 Jul 2024 12:05:52 +0200 Subject: [PATCH 12/18] clean test msel oracle --- tests/test_msel_oracle.py | 112 +++++++++----------------------------- 1 file changed, 25 insertions(+), 87 deletions(-) diff --git a/tests/test_msel_oracle.py b/tests/test_msel_oracle.py index 8a28567b0..3eaef27d2 100644 --- a/tests/test_msel_oracle.py +++ b/tests/test_msel_oracle.py @@ -17,6 +17,15 @@ from domainlab.utils.utils_cuda import get_device +def mk_model(task): + backbone = torchvisionmodels.resnet.resnet50(weights=ResNet50_Weights.IMAGENET1K_V2) + num_final_in = backbone.fc.in_features + backbone.fc = nn.Linear(num_final_in, task.dim_y) + + # specify model to use + model = mk_erm(list_str_y=task.list_str_y)(backbone) + return model + def mk_exp( task, model, @@ -70,10 +79,7 @@ def mk_exp( return exp -def test_msel_oracle(): - """ - return trainer, model, observer - """ +def mk_task(): task = mk_task_dset(isize=ImSize(3, 28, 28), dim_y=10, taskna="custom_task") task.add_domain( name="domain1", @@ -90,8 +96,15 @@ def test_msel_oracle(): dset_tr=DsetMNISTColorSoloDefault(4), dset_val=DsetMNISTColorSoloDefault(5), ) + return task + +def test_msel_oracle(): + """ + return trainer, model, observer + """ # specify backbone to use + task = mk_task() backbone = torchvisionmodels.resnet.resnet50(weights=ResNet50_Weights.IMAGENET1K_V2) num_final_in = backbone.fc.in_features backbone.fc = nn.Linear(num_final_in, task.dim_y) @@ -110,24 +123,8 @@ def test_msel_oracle1(): """ return trainer, model, observer """ - task = mk_task_dset(isize=ImSize(3, 28, 28), dim_y=10, taskna="custom_task") - task.add_domain( - name="domain1", - dset_tr=DsetMNISTColorSoloDefault(0), - dset_val=DsetMNISTColorSoloDefault(1), - ) - task.add_domain( - name="domain2", - dset_tr=DsetMNISTColorSoloDefault(2), - dset_val=DsetMNISTColorSoloDefault(3), - ) - task.add_domain( - name="domain3", - dset_tr=DsetMNISTColorSoloDefault(4), - dset_val=DsetMNISTColorSoloDefault(5), - ) - - # specify backbone to use + task = mk_task() + # specify backbone to use backbone = torchvisionmodels.resnet.resnet50(weights=ResNet50_Weights.IMAGENET1K_V2) num_final_in = backbone.fc.in_features backbone.fc = nn.Linear(num_final_in, task.dim_y) @@ -150,23 +147,7 @@ def test_msel_oracle2(): """ return trainer, model, observer """ - task = mk_task_dset(isize=ImSize(3, 28, 28), dim_y=10, taskna="custom_task") - task.add_domain( - name="domain1", - dset_tr=DsetMNISTColorSoloDefault(0), - dset_val=DsetMNISTColorSoloDefault(1), - ) - task.add_domain( - name="domain2", - dset_tr=DsetMNISTColorSoloDefault(2), - dset_val=DsetMNISTColorSoloDefault(3), - ) - task.add_domain( - name="domain3", - dset_tr=DsetMNISTColorSoloDefault(4), - dset_val=DsetMNISTColorSoloDefault(5), - ) - + task = mk_task() # specify backbone to use backbone = torchvisionmodels.resnet.resnet50(weights=ResNet50_Weights.IMAGENET1K_V2) num_final_in = backbone.fc.in_features @@ -184,30 +165,8 @@ def test_msel_oracle3(): """ return trainer, model, observer """ - task = mk_task_dset(isize=ImSize(3, 28, 28), dim_y=10, taskna="custom_task") - task.add_domain( - name="domain1", - dset_tr=DsetMNISTColorSoloDefault(0), - dset_val=DsetMNISTColorSoloDefault(1), - ) - task.add_domain( - name="domain2", - dset_tr=DsetMNISTColorSoloDefault(2), - dset_val=DsetMNISTColorSoloDefault(3), - ) - task.add_domain( - name="domain3", - dset_tr=DsetMNISTColorSoloDefault(4), - dset_val=DsetMNISTColorSoloDefault(5), - ) - - # specify backbone to use - backbone = torchvisionmodels.resnet.resnet50(weights=ResNet50_Weights.IMAGENET1K_V2) - num_final_in = backbone.fc.in_features - backbone.fc = nn.Linear(num_final_in, task.dim_y) - - # specify model to use - model = mk_erm(list_str_y=task.list_str_y)(backbone) + task = mk_task() + model = mk_model(task) exp = mk_exp( task, @@ -226,30 +185,9 @@ def test_msel_oracle4(): """ return trainer, model, observer """ - task = mk_task_dset(isize=ImSize(3, 28, 28), dim_y=10, taskna="custom_task") - task.add_domain( - name="domain1", - dset_tr=DsetMNISTColorSoloDefault(0), - dset_val=DsetMNISTColorSoloDefault(1), - ) - task.add_domain( - name="domain2", - dset_tr=DsetMNISTColorSoloDefault(2), - dset_val=DsetMNISTColorSoloDefault(3), - ) - task.add_domain( - name="domain3", - dset_tr=DsetMNISTColorSoloDefault(4), - dset_val=DsetMNISTColorSoloDefault(5), - ) - + task = mk_task() + model = mk_model(task) # specify backbone to use - backbone = torchvisionmodels.resnet.resnet50(weights=ResNet50_Weights.IMAGENET1K_V2) - num_final_in = backbone.fc.in_features - backbone.fc = nn.Linear(num_final_in, task.dim_y) - - # specify model to use - model = mk_erm(list_str_y=task.list_str_y)(backbone) exp = mk_exp( task, model, @@ -261,5 +199,5 @@ def test_msel_oracle4(): ) exp.execute(num_epochs=2) exp.trainer.observer.model_sel.msel.best_loss = 0 - exp.trainer.observer.model_sel.msel.update(epoch = 1, clear_counter=True) + exp.trainer.observer.model_sel.msel.update(epoch=1, clear_counter=True) del exp From a5057072d494aedcea989684a09c47d7ff56c18a Mon Sep 17 00:00:00 2001 From: smilesun Date: Fri, 26 Jul 2024 12:26:22 +0200 Subject: [PATCH 13/18] . --- tests/test_msel_oracle.py | 27 +++------------------------ 1 file changed, 3 insertions(+), 24 deletions(-) diff --git a/tests/test_msel_oracle.py b/tests/test_msel_oracle.py index 3eaef27d2..551860149 100644 --- a/tests/test_msel_oracle.py +++ b/tests/test_msel_oracle.py @@ -105,13 +105,7 @@ def test_msel_oracle(): """ # specify backbone to use task = mk_task() - backbone = torchvisionmodels.resnet.resnet50(weights=ResNet50_Weights.IMAGENET1K_V2) - num_final_in = backbone.fc.in_features - backbone.fc = nn.Linear(num_final_in, task.dim_y) - - # specify model to use - model = mk_erm(list_str_y=task.list_str_y)(backbone) - + model = mk_model(task) # make trainer for model exp = mk_exp(task, model, trainer="mldg", test_domain="domain1", batchsize=32) exp.execute(num_epochs=2) @@ -124,16 +118,7 @@ def test_msel_oracle1(): return trainer, model, observer """ task = mk_task() - # specify backbone to use - backbone = torchvisionmodels.resnet.resnet50(weights=ResNet50_Weights.IMAGENET1K_V2) - num_final_in = backbone.fc.in_features - backbone.fc = nn.Linear(num_final_in, task.dim_y) - - # specify model to use - model = mk_erm(list_str_y=task.list_str_y)(backbone) - - # make trainer for model - + model = mk_model(task) exp = mk_exp( task, model, trainer="mldg", test_domain="domain1", batchsize=32, alone=False ) @@ -148,13 +133,7 @@ def test_msel_oracle2(): return trainer, model, observer """ task = mk_task() - # specify backbone to use - backbone = torchvisionmodels.resnet.resnet50(weights=ResNet50_Weights.IMAGENET1K_V2) - num_final_in = backbone.fc.in_features - backbone.fc = nn.Linear(num_final_in, task.dim_y) - - # specify model to use - model = mk_erm(list_str_y=task.list_str_y)(backbone) + model = mk_model(task) # make trainer for model exp = mk_exp(task, model, trainer="mldg", test_domain="domain1", batchsize=32) From 8efecac6af9df0da6877652d23cecd809a78da3d Mon Sep 17 00:00:00 2001 From: smilesun Date: Fri, 26 Jul 2024 12:57:22 +0200 Subject: [PATCH 14/18] decrease batch size for msel_orcale test --- tests/test_msel_oracle.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/tests/test_msel_oracle.py b/tests/test_msel_oracle.py index 551860149..8bcf37a92 100644 --- a/tests/test_msel_oracle.py +++ b/tests/test_msel_oracle.py @@ -66,7 +66,6 @@ def mk_exp( parser = mk_parser_main() conf = parser.parse_args(str_arg.split()) - device = get_device(conf) if alone: model_sel = MSelOracleVisitor() else: @@ -107,7 +106,7 @@ def test_msel_oracle(): task = mk_task() model = mk_model(task) # make trainer for model - exp = mk_exp(task, model, trainer="mldg", test_domain="domain1", batchsize=32) + exp = mk_exp(task, model, trainer="mldg", test_domain="domain1", batchsize=2) exp.execute(num_epochs=2) del exp @@ -120,7 +119,7 @@ def test_msel_oracle1(): task = mk_task() model = mk_model(task) exp = mk_exp( - task, model, trainer="mldg", test_domain="domain1", batchsize=32, alone=False + task, model, trainer="mldg", test_domain="domain1", batchsize=2, alone=False ) exp.execute(num_epochs=2) @@ -136,7 +135,7 @@ def test_msel_oracle2(): model = mk_model(task) # make trainer for model - exp = mk_exp(task, model, trainer="mldg", test_domain="domain1", batchsize=32) + exp = mk_exp(task, model, trainer="mldg", test_domain="domain1", batchsize=2) exp.execute(num_epochs=2) @@ -152,7 +151,7 @@ def test_msel_oracle3(): model, trainer="mldg", test_domain="domain1", - batchsize=32, + batchsize=2, alone=False, force_best_val=True, ) @@ -172,7 +171,7 @@ def test_msel_oracle4(): model, trainer="mldg", test_domain="domain1", - batchsize=32, + batchsize=2, alone=False, msel_loss_tr=True, ) From 235832abf4e185d67f3779d36ad64818c7f8d8a8 Mon Sep 17 00:00:00 2001 From: smilesun Date: Fri, 26 Jul 2024 15:21:53 +0200 Subject: [PATCH 15/18] split msel_oracle test --- tests/test_msel_oracle.py | 119 +------------------------------------ tests/test_msel_oracle2.py | 26 ++++++++ tests/utils_task_model.py | 94 +++++++++++++++++++++++++++++ 3 files changed, 121 insertions(+), 118 deletions(-) create mode 100644 tests/test_msel_oracle2.py create mode 100644 tests/utils_task_model.py diff --git a/tests/test_msel_oracle.py b/tests/test_msel_oracle.py index 8bcf37a92..b8941f496 100644 --- a/tests/test_msel_oracle.py +++ b/tests/test_msel_oracle.py @@ -1,102 +1,7 @@ """ executing mk_exp multiple times will cause deep copy to be called multiple times, pytest will show process got killed. """ -from torch import nn -from torchvision import models as torchvisionmodels -from torchvision.models import ResNet50_Weights - -from domainlab.algos.msels.c_msel_oracle import MSelOracleVisitor -from domainlab.algos.msels.c_msel_val import MSelValPerf -from domainlab.algos.observers.b_obvisitor import ObVisitor -from domainlab.arg_parser import mk_parser_main -from domainlab.dsets.dset_mnist_color_solo_default import DsetMNISTColorSoloDefault -from domainlab.exp.exp_main import Exp -from domainlab.models.model_erm import mk_erm -from domainlab.tasks.task_dset import mk_task_dset -from domainlab.tasks.utils_task import ImSize -from domainlab.utils.utils_cuda import get_device - - -def mk_model(task): - backbone = torchvisionmodels.resnet.resnet50(weights=ResNet50_Weights.IMAGENET1K_V2) - num_final_in = backbone.fc.in_features - backbone.fc = nn.Linear(num_final_in, task.dim_y) - - # specify model to use - model = mk_erm(list_str_y=task.list_str_y)(backbone) - return model - -def mk_exp( - task, - model, - trainer: str, - test_domain: str, - batchsize: int, - alone=True, - force_best_val=False, - msel_loss_tr=False, -): - """ - Creates a custom experiment. The user can specify the input parameters. - - Input Parameters: - - task: create a task to a custom dataset by importing "mk_task_dset" - function from - "domainlab.tasks.task_dset". For more explanation on the input params - refer to the - documentation found in "domainlab.tasks.task_dset.py". - - model: create a model [NameOfModel] by importing "mk_[NameOfModel]" - function from - "domainlab.models.model_[NameOfModel]". For a concrete example and - explanation of the input - params refer to the documentation found in - "domainlab.models.model_[NameOfModel].py" - - trainer: string, - - test_domain: string, - - batch size: int - - Returns: experiment - """ - - str_arg = f"--model=apimodel --trainer={trainer} \ - --te_d={test_domain} --bs={batchsize}" - if msel_loss_tr: - str_arg = f"--model=apimodel --trainer={trainer} \ - --te_d={test_domain} --bs={batchsize} --msel=loss_tr" - - parser = mk_parser_main() - conf = parser.parse_args(str_arg.split()) - if alone: - model_sel = MSelOracleVisitor() - else: - model_sel = MSelOracleVisitor(MSelValPerf(max_es=0)) - if force_best_val: - model_sel.msel._best_val_acc = 1.0 - observer = ObVisitor(model_sel) - exp = Exp(conf, task, model=model, observer=observer) - model_sel.update(epoch=1, clear_counter=True) - return exp - - -def mk_task(): - task = mk_task_dset(isize=ImSize(3, 28, 28), dim_y=10, taskna="custom_task") - task.add_domain( - name="domain1", - dset_tr=DsetMNISTColorSoloDefault(0), - dset_val=DsetMNISTColorSoloDefault(1), - ) - task.add_domain( - name="domain2", - dset_tr=DsetMNISTColorSoloDefault(2), - dset_val=DsetMNISTColorSoloDefault(3), - ) - task.add_domain( - name="domain3", - dset_tr=DsetMNISTColorSoloDefault(4), - dset_val=DsetMNISTColorSoloDefault(5), - ) - return task - +from tests.utils_task_model import mk_exp, mk_model, mk_task def test_msel_oracle(): """ @@ -157,25 +62,3 @@ def test_msel_oracle3(): ) exp.execute(num_epochs=2) del exp - - -def test_msel_oracle4(): - """ - return trainer, model, observer - """ - task = mk_task() - model = mk_model(task) - # specify backbone to use - exp = mk_exp( - task, - model, - trainer="mldg", - test_domain="domain1", - batchsize=2, - alone=False, - msel_loss_tr=True, - ) - exp.execute(num_epochs=2) - exp.trainer.observer.model_sel.msel.best_loss = 0 - exp.trainer.observer.model_sel.msel.update(epoch=1, clear_counter=True) - del exp diff --git a/tests/test_msel_oracle2.py b/tests/test_msel_oracle2.py new file mode 100644 index 000000000..0f3245305 --- /dev/null +++ b/tests/test_msel_oracle2.py @@ -0,0 +1,26 @@ +""" +executing mk_exp multiple times will cause deep copy to be called multiple times, pytest will show process got killed. +""" +from tests.utils_task_model import mk_exp, mk_model, mk_task + + +def test_msel_oracle4(): + """ + return trainer, model, observer + """ + task = mk_task() + model = mk_model(task) + # specify backbone to use + exp = mk_exp( + task, + model, + trainer="mldg", + test_domain="domain1", + batchsize=2, + alone=False, + msel_loss_tr=True, + ) + exp.execute(num_epochs=2) + exp.trainer.observer.model_sel.msel.best_loss = 0 + exp.trainer.observer.model_sel.msel.update(epoch=1, clear_counter=True) + del exp diff --git a/tests/utils_task_model.py b/tests/utils_task_model.py new file mode 100644 index 000000000..4fa70796a --- /dev/null +++ b/tests/utils_task_model.py @@ -0,0 +1,94 @@ +from torch import nn +from torchvision import models as torchvisionmodels +from torchvision.models import ResNet50_Weights + +from domainlab.algos.msels.c_msel_oracle import MSelOracleVisitor +from domainlab.algos.msels.c_msel_val import MSelValPerf +from domainlab.algos.observers.b_obvisitor import ObVisitor +from domainlab.arg_parser import mk_parser_main +from domainlab.dsets.dset_mnist_color_solo_default import DsetMNISTColorSoloDefault +from domainlab.exp.exp_main import Exp +from domainlab.models.model_erm import mk_erm +from domainlab.tasks.task_dset import mk_task_dset +from domainlab.tasks.utils_task import ImSize + + +def mk_model(task): + backbone = torchvisionmodels.resnet.resnet50(weights=ResNet50_Weights.IMAGENET1K_V2) + num_final_in = backbone.fc.in_features + backbone.fc = nn.Linear(num_final_in, task.dim_y) + + # specify model to use + model = mk_erm(list_str_y=task.list_str_y)(backbone) + return model + +def mk_exp( + task, + model, + trainer: str, + test_domain: str, + batchsize: int, + alone=True, + force_best_val=False, + msel_loss_tr=False, +): + """ + Creates a custom experiment. The user can specify the input parameters. + + Input Parameters: + - task: create a task to a custom dataset by importing "mk_task_dset" + function from + "domainlab.tasks.task_dset". For more explanation on the input params + refer to the + documentation found in "domainlab.tasks.task_dset.py". + - model: create a model [NameOfModel] by importing "mk_[NameOfModel]" + function from + "domainlab.models.model_[NameOfModel]". For a concrete example and + explanation of the input + params refer to the documentation found in + "domainlab.models.model_[NameOfModel].py" + - trainer: string, + - test_domain: string, + - batch size: int + + Returns: experiment + """ + + str_arg = f"--model=apimodel --trainer={trainer} \ + --te_d={test_domain} --bs={batchsize}" + if msel_loss_tr: + str_arg = f"--model=apimodel --trainer={trainer} \ + --te_d={test_domain} --bs={batchsize} --msel=loss_tr" + + parser = mk_parser_main() + conf = parser.parse_args(str_arg.split()) + if alone: + model_sel = MSelOracleVisitor() + else: + model_sel = MSelOracleVisitor(MSelValPerf(max_es=0)) + if force_best_val: + model_sel.msel._best_val_acc = 1.0 + observer = ObVisitor(model_sel) + exp = Exp(conf, task, model=model, observer=observer) + model_sel.update(epoch=1, clear_counter=True) + return exp + + +def mk_task(): + task = mk_task_dset(isize=ImSize(3, 28, 28), dim_y=10, taskna="custom_task") + task.add_domain( + name="domain1", + dset_tr=DsetMNISTColorSoloDefault(0), + dset_val=DsetMNISTColorSoloDefault(1), + ) + task.add_domain( + name="domain2", + dset_tr=DsetMNISTColorSoloDefault(2), + dset_val=DsetMNISTColorSoloDefault(3), + ) + task.add_domain( + name="domain3", + dset_tr=DsetMNISTColorSoloDefault(4), + dset_val=DsetMNISTColorSoloDefault(5), + ) + return task From 1feff92f8bbc327314ce70d67f696402b2f3f716 Mon Sep 17 00:00:00 2001 From: Xudong Sun Date: Fri, 26 Jul 2024 17:06:50 +0200 Subject: [PATCH 16/18] Update test_msel_oracle.py --- tests/test_msel_oracle.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_msel_oracle.py b/tests/test_msel_oracle.py index b8941f496..3f2143c70 100644 --- a/tests/test_msel_oracle.py +++ b/tests/test_msel_oracle.py @@ -42,6 +42,7 @@ def test_msel_oracle2(): # make trainer for model exp = mk_exp(task, model, trainer="mldg", test_domain="domain1", batchsize=2) exp.execute(num_epochs=2) + del exp def test_msel_oracle3(): From d012b94fde3eb347c638e0af6055aca95e5c3413 Mon Sep 17 00:00:00 2001 From: Xudong Sun Date: Fri, 26 Jul 2024 17:12:32 +0200 Subject: [PATCH 17/18] Update test_msel_oracle.py --- tests/test_msel_oracle.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/test_msel_oracle.py b/tests/test_msel_oracle.py index 3f2143c70..03012bb18 100644 --- a/tests/test_msel_oracle.py +++ b/tests/test_msel_oracle.py @@ -13,7 +13,7 @@ def test_msel_oracle(): # make trainer for model exp = mk_exp(task, model, trainer="mldg", test_domain="domain1", batchsize=2) exp.execute(num_epochs=2) - + exp.clean_up() del exp @@ -29,6 +29,7 @@ def test_msel_oracle1(): exp.execute(num_epochs=2) exp.trainer.observer.model_sel.msel.update(epoch=1, clear_counter=True) + exp.clean_up() del exp @@ -42,6 +43,7 @@ def test_msel_oracle2(): # make trainer for model exp = mk_exp(task, model, trainer="mldg", test_domain="domain1", batchsize=2) exp.execute(num_epochs=2) + exp.clean_up() del exp @@ -62,4 +64,5 @@ def test_msel_oracle3(): force_best_val=True, ) exp.execute(num_epochs=2) + exp.clean_up() del exp From 34975daeba869f9f88032fa99bfc1f250b48f51c Mon Sep 17 00:00:00 2001 From: Xudong Sun Date: Fri, 26 Jul 2024 17:12:56 +0200 Subject: [PATCH 18/18] Update test_msel_oracle2.py --- tests/test_msel_oracle2.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_msel_oracle2.py b/tests/test_msel_oracle2.py index 0f3245305..651a38f2b 100644 --- a/tests/test_msel_oracle2.py +++ b/tests/test_msel_oracle2.py @@ -23,4 +23,5 @@ def test_msel_oracle4(): exp.execute(num_epochs=2) exp.trainer.observer.model_sel.msel.best_loss = 0 exp.trainer.observer.model_sel.msel.update(epoch=1, clear_counter=True) + exp.clean_up() del exp