From 188b5de6d554ce323956c22f25dea3d7232ad386 Mon Sep 17 00:00:00 2001 From: Xudong Sun Date: Sun, 8 Oct 2023 12:31:17 +0200 Subject: [PATCH 1/8] Update a_model_sel.py --- domainlab/algos/msels/a_model_sel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/domainlab/algos/msels/a_model_sel.py b/domainlab/algos/msels/a_model_sel.py index 1f20ff912..f6ae57799 100644 --- a/domainlab/algos/msels/a_model_sel.py +++ b/domainlab/algos/msels/a_model_sel.py @@ -25,7 +25,7 @@ def accept(self, trainer, tr_obs): self.tr_obs = tr_obs @abc.abstractmethod - def update(self): + def update(self, clear_counter=False): """ observer + visitor pattern to trainer if the best model should be updated From 5775c10624a70f822dabf8649560b8271f4e046e Mon Sep 17 00:00:00 2001 From: Xudong Sun Date: Sun, 8 Oct 2023 12:42:35 +0200 Subject: [PATCH 2/8] Update c_msel_val.py --- domainlab/algos/msels/c_msel_val.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/domainlab/algos/msels/c_msel_val.py b/domainlab/algos/msels/c_msel_val.py index 939cc47a0..01497f0c4 100644 --- a/domainlab/algos/msels/c_msel_val.py +++ b/domainlab/algos/msels/c_msel_val.py @@ -16,19 +16,19 @@ def __init__(self, max_es): self.best_te_metric = 0.0 super().__init__(max_es) # construct self.tr_obs (observer) - def update(self): + def update(self, clear_counter=False): """ if the best model should be updated """ flag = True if self.tr_obs.metric_val is None or self.tr_obs.str_msel == "loss_tr": - return super().update() + return super().update(clear_counter) metric = self.tr_obs.metric_val[self.tr_obs.str_metric4msel] if self.tr_obs.metric_te is not None: metric_te_current = self.tr_obs.metric_te[self.tr_obs.str_metric4msel] self.best_te_metric = max(self.best_te_metric, metric_te_current) - if metric > self.best_val_acc: # observer + if metric > self.best_val_acc: # update hat{model} # different from loss, accuracy should be improved: the bigger the better self.best_val_acc = metric self.es_c = 0 # restore counter @@ -45,5 +45,7 @@ def update(self): f"corresponding to test acc: \ {self.sel_model_te_acc} / {self.best_te_metric}") flag = False # do not update best model - + if clear_counter: + logger.info("clearing counter") + self.es_c = 0 return flag From 23accdb6a5698118f7622f78f0c99171e58291d7 Mon Sep 17 00:00:00 2001 From: Xudong Sun Date: Sun, 8 Oct 2023 12:43:08 +0200 Subject: [PATCH 3/8] Update c_msel_tr_loss.py --- domainlab/algos/msels/c_msel_tr_loss.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/domainlab/algos/msels/c_msel_tr_loss.py b/domainlab/algos/msels/c_msel_tr_loss.py index c42f324b8..3b9d4581e 100644 --- a/domainlab/algos/msels/c_msel_tr_loss.py +++ b/domainlab/algos/msels/c_msel_tr_loss.py @@ -17,7 +17,7 @@ def __init__(self, max_es): self.max_es = max_es super().__init__() - def update(self): + def update(self, clear_counter=False): """ if the best model should be updated """ @@ -34,6 +34,9 @@ def update(self): logger.info(f"early stop counter: {self.es_c}") logger.info(f"loss:{loss}, best loss: {self.best_loss}") flag = False # do not update best model + if clear_counter: + logger.info("clearing counter") + self.es_c = 0 return flag def if_stop(self): From 9e80df565e415c5df9806d223dc01d78a3998d4c Mon Sep 17 00:00:00 2001 From: Xudong Sun Date: Sun, 8 Oct 2023 12:43:59 +0200 Subject: [PATCH 4/8] Update c_msel_oracle.py --- domainlab/algos/msels/c_msel_oracle.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/domainlab/algos/msels/c_msel_oracle.py b/domainlab/algos/msels/c_msel_oracle.py index eb672f46e..299a9e48b 100644 --- a/domainlab/algos/msels/c_msel_oracle.py +++ b/domainlab/algos/msels/c_msel_oracle.py @@ -18,7 +18,7 @@ def __init__(self, msel=None): self.best_oracle_acc = 0 self.msel = msel - def update(self): + def update(self, clear_counter=False): """ if the best model should be updated """ @@ -35,7 +35,7 @@ def update(self): logger.info("new oracle model saved") flag = True if self.msel is not None: - return self.msel.update() + return self.msel.update(clear_counter) return flag def if_stop(self): From 35b5ea38962a554ca4bd58ddf8ad5cac955776c8 Mon Sep 17 00:00:00 2001 From: Xudong Sun Date: Sun, 8 Oct 2023 16:41:45 +0200 Subject: [PATCH 5/8] Create test_observer.py --- tests/test_observer.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) create mode 100644 tests/test_observer.py diff --git a/tests/test_observer.py b/tests/test_observer.py new file mode 100644 index 000000000..e01de2422 --- /dev/null +++ b/tests/test_observer.py @@ -0,0 +1,28 @@ +""" +unit and end-end test for deep all, dann +""" +import os +import gc +import torch +from domainlab.compos.exp.exp_main import Exp +from domainlab.arg_parser import mk_parser_main +from tests.utils_test import utils_test_algo + + +def test_deepall(): + """ + unit deep all + """ + parser = mk_parser_main() + margs = parser.parse_args(["--te_d", "caltech", + "--task", "mini_vlcs", + "--aname", "deepall", "--bs", "2", + "--nname", "conv_bn_pool_2" + ]) + exp = Exp(margs) + exp.trainer.before_tr() + exp.trainer.tr_epoch(0) + exp.trainer.observer.update(True) + del exp + torch.cuda.empty_cache() + gc.collect() From ed209dd06457b42cb27d3a88f325afe243de074e Mon Sep 17 00:00:00 2001 From: Xudong Sun Date: Sun, 8 Oct 2023 16:57:48 +0200 Subject: [PATCH 6/8] Update test_observer.py --- tests/test_observer.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/tests/test_observer.py b/tests/test_observer.py index e01de2422..81dc4e118 100644 --- a/tests/test_observer.py +++ b/tests/test_observer.py @@ -26,3 +26,23 @@ def test_deepall(): del exp torch.cuda.empty_cache() gc.collect() + + +def test_deepall_trloss(): + """ + unit deep all + """ + parser = mk_parser_main() + margs = parser.parse_args(["--te_d", "caltech", + "--task", "mini_vlcs", + "--aname", "deepall", "--bs", "2", + "--nname", "conv_bn_pool_2", + "--msel", "loss_tr" + ]) + exp = Exp(margs) + exp.trainer.before_tr() + exp.trainer.tr_epoch(0) + exp.trainer.observer.update(True) + del exp + torch.cuda.empty_cache() + gc.collect() From 3cdc1e22543e83741dc5158fb69ab16331b25bf8 Mon Sep 17 00:00:00 2001 From: Xudong Sun Date: Sun, 8 Oct 2023 17:04:20 +0200 Subject: [PATCH 7/8] Update test_observer.py --- tests/test_observer.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/test_observer.py b/tests/test_observer.py index 81dc4e118..46e2cfd87 100644 --- a/tests/test_observer.py +++ b/tests/test_observer.py @@ -1,12 +1,10 @@ """ unit and end-end test for deep all, dann """ -import os import gc import torch from domainlab.compos.exp.exp_main import Exp from domainlab.arg_parser import mk_parser_main -from tests.utils_test import utils_test_algo def test_deepall(): From d6b91273ce09bc4ac637fabae8b6308f952ee5a4 Mon Sep 17 00:00:00 2001 From: Xudong Sun Date: Sun, 8 Oct 2023 17:43:46 +0200 Subject: [PATCH 8/8] Update ci_run_examples.sh --- ci_run_examples.sh | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ci_run_examples.sh b/ci_run_examples.sh index 23249327f..3c00bef61 100644 --- a/ci_run_examples.sh +++ b/ci_run_examples.sh @@ -8,6 +8,8 @@ sed -n '/```shell/,/```/ p' docs/doc_examples.md | sed '/^```/ d' >> ./sh_temp_e bash -x -v -e sh_temp_example.sh echo "general examples done" +rm -r zoutput + echo "#!/bin/bash -x -v" > sh_temp_mnist.sh sed -n '/```shell/,/```/ p' docs/doc_MNIST_classification.md | sed '/^```/ d' >> ./sh_temp_mnist.sh bash -x -v -e sh_temp_mnist.sh