Skip to content

Commit

Permalink
Merge branch 'master' into fbopt
Browse files Browse the repository at this point in the history
  • Loading branch information
smilesun authored Oct 8, 2023
2 parents 21878f5 + d6b9127 commit d3eb25c
Show file tree
Hide file tree
Showing 6 changed files with 61 additions and 8 deletions.
2 changes: 2 additions & 0 deletions ci_run_examples.sh
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ sed -n '/```shell/,/```/ p' docs/doc_examples.md | sed '/^```/ d' >> ./sh_temp_e
bash -x -v -e sh_temp_example.sh
echo "general examples done"

rm -r zoutput

echo "#!/bin/bash -x -v" > sh_temp_mnist.sh
sed -n '/```shell/,/```/ p' docs/doc_MNIST_classification.md | sed '/^```/ d' >> ./sh_temp_mnist.sh
bash -x -v -e sh_temp_mnist.sh
Expand Down
2 changes: 1 addition & 1 deletion domainlab/algos/msels/a_model_sel.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def accept(self, trainer, tr_obs):
self.tr_obs = tr_obs

@abc.abstractmethod
def update(self):
def update(self, clear_counter=False):
"""
observer + visitor pattern to trainer
if the best model should be updated
Expand Down
4 changes: 2 additions & 2 deletions domainlab/algos/msels/c_msel_oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def __init__(self, msel=None):
self.best_oracle_acc = 0
self.msel = msel

def update(self):
def update(self, clear_counter=False):
"""
if the best model should be updated
"""
Expand All @@ -35,7 +35,7 @@ def update(self):
logger.info("new oracle model saved")
flag = True
if self.msel is not None:
return self.msel.update()
return self.msel.update(clear_counter)
return flag

def if_stop(self):
Expand Down
5 changes: 4 additions & 1 deletion domainlab/algos/msels/c_msel_tr_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def __init__(self, max_es):
self.max_es = max_es
super().__init__()

def update(self):
def update(self, clear_counter=False):
"""
if the best model should be updated
"""
Expand All @@ -34,6 +34,9 @@ def update(self):
logger.info(f"early stop counter: {self.es_c}")
logger.info(f"loss:{loss}, best loss: {self.best_loss}")
flag = False # do not update best model
if clear_counter:
logger.info("clearing counter")
self.es_c = 0
return flag

def if_stop(self):
Expand Down
10 changes: 6 additions & 4 deletions domainlab/algos/msels/c_msel_val.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,19 @@ def __init__(self, max_es):
self.best_te_metric = 0.0
super().__init__(max_es) # construct self.tr_obs (observer)

def update(self):
def update(self, clear_counter=False):
"""
if the best model should be updated
"""
flag = True
if self.tr_obs.metric_val is None or self.tr_obs.str_msel == "loss_tr":
return super().update()
return super().update(clear_counter)
metric = self.tr_obs.metric_val[self.tr_obs.str_metric4msel]
if self.tr_obs.metric_te is not None:
metric_te_current = self.tr_obs.metric_te[self.tr_obs.str_metric4msel]
self.best_te_metric = max(self.best_te_metric, metric_te_current)

if metric > self.best_val_acc: # observer
if metric > self.best_val_acc: # update hat{model}
# different from loss, accuracy should be improved: the bigger the better
self.best_val_acc = metric
self.es_c = 0 # restore counter
Expand All @@ -45,5 +45,7 @@ def update(self):
f"corresponding to test acc: \
{self.sel_model_te_acc} / {self.best_te_metric}")
flag = False # do not update best model

if clear_counter:
logger.info("clearing counter")
self.es_c = 0
return flag
46 changes: 46 additions & 0 deletions tests/test_observer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
"""
unit and end-end test for deep all, dann
"""
import gc
import torch
from domainlab.compos.exp.exp_main import Exp
from domainlab.arg_parser import mk_parser_main


def test_deepall():
"""
unit deep all
"""
parser = mk_parser_main()
margs = parser.parse_args(["--te_d", "caltech",
"--task", "mini_vlcs",
"--aname", "deepall", "--bs", "2",
"--nname", "conv_bn_pool_2"
])
exp = Exp(margs)
exp.trainer.before_tr()
exp.trainer.tr_epoch(0)
exp.trainer.observer.update(True)
del exp
torch.cuda.empty_cache()
gc.collect()


def test_deepall_trloss():
"""
unit deep all
"""
parser = mk_parser_main()
margs = parser.parse_args(["--te_d", "caltech",
"--task", "mini_vlcs",
"--aname", "deepall", "--bs", "2",
"--nname", "conv_bn_pool_2",
"--msel", "loss_tr"
])
exp = Exp(margs)
exp.trainer.before_tr()
exp.trainer.tr_epoch(0)
exp.trainer.observer.update(True)
del exp
torch.cuda.empty_cache()
gc.collect()

0 comments on commit d3eb25c

Please sign in to comment.