Skip to content

Commit

Permalink
.
Browse files Browse the repository at this point in the history
  • Loading branch information
smilesun committed May 7, 2024
1 parent c2eafc5 commit 0fca21d
Show file tree
Hide file tree
Showing 42 changed files with 772 additions and 220 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,9 @@
),
isize=ImSize(3, 224, 224),
dict_domain2imgroot={
"caltech": os.path.join(path_this_file, "../../data/vlcs_mini/caltech/"),
"sun": os.path.join(path_this_file, "../../data/vlcs_mini/sun/"),
"labelme": os.path.join(path_this_file, "../../data/vlcs_mini/labelme/"),
"caltech": os.path.join(path_this_file, "../../domainlab/zdata/vlcs_mini/caltech/"),
"sun": os.path.join(path_this_file, "../../domainlab/zdata/vlcs_mini/sun/"),
"labelme": os.path.join(path_this_file, "../../domainlab/zdata/vlcs_mini/labelme/"),
},
taskna="e_mini_vlcs",
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,7 @@ <h1 id="modules-domainlab-algos-builder-api-model--page-root">Source code for do
"""
args = exp.args
device = get_device(args)
model_sel = MSelOracleVisitor(MSelValPerf(max_es=args.es))
model_sel = MSelOracleVisitor(MSelValPerf(max_es=args.es), val_threshold=args.val_threshold)
observer = ObVisitor(model_sel)
trainer = TrainerChainNodeGetter(args.trainer)(default="hyperscheduler")
return trainer, None, observer, device</div></div>
Expand Down
3 changes: 2 additions & 1 deletion docs/build/html/_modules/domainlab/algos/builder_custom.html
Original file line number Diff line number Diff line change
Expand Up @@ -413,7 +413,8 @@ <h1 id="modules-domainlab-algos-builder-custom--page-root">Source code for domai
task = exp.task
args = exp.args
device = get_device(args)
model_sel = MSelOracleVisitor(MSelValPerf(max_es=args.es))
model_sel = MSelOracleVisitor(MSelValPerf(max_es=args.es),
val_threshold=args.val_threshold)
observer = ObVisitor(model_sel)
model = class_name_model(net_classifier=None, list_str_y=task.list_str_y)
model = self.init_next_model(model, exp)
Expand Down
2 changes: 1 addition & 1 deletion docs/build/html/_modules/domainlab/algos/builder_dann.html
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ <h1 id="modules-domainlab-algos-builder-dann--page-root">Source code for domainl
args = exp.args
task.get_list_domains_tr_te(args.tr_d, args.te_d)
device = get_device(args)
msel = MSelOracleVisitor(MSelValPerf(max_es=args.es))
msel = MSelOracleVisitor(MSelValPerf(max_es=args.es), val_threshold=args.val_threshold)
observer = ObVisitor(msel)
observer = ObVisitorCleanUp(observer)

Expand Down
2 changes: 1 addition & 1 deletion docs/build/html/_modules/domainlab/algos/builder_diva.html
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@ <h1 id="modules-domainlab-algos-builder-diva--page-root">Source code for domainl
beta_d=args.beta_d,
)
device = get_device(args)
model_sel = MSelOracleVisitor(MSelValPerf(max_es=args.es))
model_sel = MSelOracleVisitor(MSelValPerf(max_es=args.es), val_threshold=args.val_threshold)
if not args.gen:
observer = ObVisitor(model_sel)
else:
Expand Down
2 changes: 1 addition & 1 deletion docs/build/html/_modules/domainlab/algos/builder_erm.html
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ <h1 id="modules-domainlab-algos-builder-erm--page-root">Source code for domainla
task = exp.task
args = exp.args
device = get_device(args)
model_sel = MSelOracleVisitor(MSelValPerf(max_es=args.es))
model_sel = MSelOracleVisitor(MSelValPerf(max_es=args.es), val_threshold=args.val_threshold)
observer = ObVisitor(model_sel)

builder = FeatExtractNNBuilderChainNodeGetter(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@ <h1 id="modules-domainlab-algos-builder-hduva--page-root">Source code for domain
beta_d=args.beta_d,
)
model = self.init_next_model(model, exp)
model_sel = MSelOracleVisitor(MSelValPerf(max_es=args.es))
model_sel = MSelOracleVisitor(MSelValPerf(max_es=args.es), val_threshold=args.val_threshold)
observer = ObVisitorCleanUp(ObVisitor(model_sel))
trainer = TrainerChainNodeGetter(args.trainer)(default="hyperscheduler")
trainer.init_business(model, task, observer, device, args)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ <h1 id="modules-domainlab-algos-builder-jigen1--page-root">Source code for domai
task = exp.task
args = exp.args
device = get_device(args)
msel = MSelOracleVisitor(msel=MSelValPerf(max_es=args.es))
msel = MSelOracleVisitor(msel=MSelValPerf(max_es=args.es), val_threshold=args.val_threshold)
observer = ObVisitor(msel)
observer = ObVisitorCleanUp(observer)

Expand Down
64 changes: 51 additions & 13 deletions docs/build/html/_modules/domainlab/algos/msels/a_model_sel.html
Original file line number Diff line number Diff line change
Expand Up @@ -319,14 +319,16 @@ <h1 id="modules-domainlab-algos-msels-a-model-sel--page-root">Source code for do
Abstract Model Selection
"""

def __init__(self):
def __init__(self, val_threshold = None):
"""
trainer and tr_observer
"""
self.trainer = None
self._tr_obs = None
self._observer = None
self.msel = None
self._max_es = None
self._model_selection_epoch = None
self._val_threshold = val_threshold

<div class="viewcode-block" id="AMSel.reset"><a class="viewcode-back" href="../../../../domainlab.algos.msels.html#domainlab.algos.msels.a_model_sel.AMSel.reset">[docs]</a> def reset(self):
"""
Expand All @@ -336,11 +338,11 @@ <h1 id="modules-domainlab-algos-msels-a-model-sel--page-root">Source code for do
self.msel.reset()</div>

@property
def tr_obs(self):
def observer4msel(self):
"""
the observer from trainer
"""
return self._tr_obs
return self._observer

@property
def max_es(self):
Expand All @@ -353,35 +355,55 @@ <h1 id="modules-domainlab-algos-msels-a-model-sel--page-root">Source code for do
return self.msel.max_es
return self._max_es

<div class="viewcode-block" id="AMSel.accept"><a class="viewcode-back" href="../../../../domainlab.algos.msels.html#domainlab.algos.msels.a_model_sel.AMSel.accept">[docs]</a> def accept(self, trainer, tr_obs):
<div class="viewcode-block" id="AMSel.accept"><a class="viewcode-back" href="../../../../domainlab.algos.msels.html#domainlab.algos.msels.a_model_sel.AMSel.accept">[docs]</a> def accept(self, trainer, observer4msel):
"""
Visitor pattern to trainer
accept trainer and tr_observer
"""
self.trainer = trainer
self._tr_obs = tr_obs
self._observer = observer4msel
if self.msel is not None:
self.msel.accept(trainer, tr_obs)</div>
self.msel.accept(trainer, observer4msel)</div>

<div class="viewcode-block" id="AMSel.update"><a class="viewcode-back" href="../../../../domainlab.algos.msels.html#domainlab.algos.msels.a_model_sel.AMSel.update">[docs]</a> @abc.abstractmethod
def update(self, clear_counter=False):
<div class="viewcode-block" id="AMSel.update"><a class="viewcode-back" href="../../../../domainlab.algos.msels.html#domainlab.algos.msels.a_model_sel.AMSel.update">[docs]</a> def update(self, epoch, clear_counter=False):
"""
level above the observer + visitor pattern to get information about the epoch
"""
update = self.base_update(clear_counter)
if update:
self._model_selection_epoch = epoch

return update</div>

<div class="viewcode-block" id="AMSel.base_update"><a class="viewcode-back" href="../../../../domainlab.algos.msels.html#domainlab.algos.msels.a_model_sel.AMSel.base_update">[docs]</a> @abc.abstractmethod
def base_update(self, clear_counter=False):
"""
observer + visitor pattern to trainer
if the best model should be updated
return boolean
"""</div>

<div class="viewcode-block" id="AMSel.if_stop"><a class="viewcode-back" href="../../../../domainlab.algos.msels.html#domainlab.algos.msels.a_model_sel.AMSel.if_stop">[docs]</a> def if_stop(self):
<div class="viewcode-block" id="AMSel.if_stop"><a class="viewcode-back" href="../../../../domainlab.algos.msels.html#domainlab.algos.msels.a_model_sel.AMSel.if_stop">[docs]</a> def if_stop(self, acc_val = None):
"""
check if trainer should stop
check if trainer should stop and additionally tests for validation threshold
return boolean
"""
# NOTE: since if_stop is not abstract, one has to
# be careful to always override it in child class
# only if the child class has a decorator which will
# dispatched.
if self.msel is not None and acc_val is not None:
if self._val_threshold is not None and acc_val &lt; self._val_threshold:
return False
return self.early_stop()</div>

<div class="viewcode-block" id="AMSel.early_stop"><a class="viewcode-back" href="../../../../domainlab.algos.msels.html#domainlab.algos.msels.a_model_sel.AMSel.early_stop">[docs]</a> def early_stop(self):
"""
check if trainer should stop
return boolean
"""
if self.msel is not None:
return self.msel.if_stop()
return self.msel.early_stop()
raise NotImplementedError</div>

@property
Expand Down Expand Up @@ -409,7 +431,23 @@ <h1 id="modules-domainlab-algos-msels-a-model-sel--page-root">Source code for do
"""
if self.msel is not None:
return self.msel.sel_model_te_acc
return -1</div>
return -1

@property
def model_selection_epoch(self):
"""
the epoch when the model was selected
"""
if self._model_selection_epoch is not None:
return self._model_selection_epoch
return -1

@property
def val_threshold(self):
"""
the treshold below which we don't stop early
"""
return self._val_threshold</div>
</pre></div>

</article>
Expand Down
24 changes: 12 additions & 12 deletions docs/build/html/_modules/domainlab/algos/msels/c_msel_oracle.html
Original file line number Diff line number Diff line change
Expand Up @@ -320,11 +320,11 @@ <h1 id="modules-domainlab-algos-msels-c-msel-oracle--page-root">Source code for
how the final model is selected
"""

def __init__(self, msel=None):
def __init__(self, msel=None, val_threshold = None):
"""
Decorator pattern
"""
super().__init__()
super().__init__(val_threshold)
self.best_oracle_acc = 0
self.msel = msel

Expand All @@ -339,15 +339,15 @@ <h1 id="modules-domainlab-algos-msels-c-msel-oracle--page-root">Source code for
return self.msel.oracle_last_setpoint_sel_te_acc
return -1

<div class="viewcode-block" id="MSelOracleVisitor.update"><a class="viewcode-back" href="../../../../domainlab.algos.msels.html#domainlab.algos.msels.c_msel_oracle.MSelOracleVisitor.update">[docs]</a> def update(self, clear_counter=False):
<div class="viewcode-block" id="MSelOracleVisitor.base_update"><a class="viewcode-back" href="../../../../domainlab.algos.msels.html#domainlab.algos.msels.c_msel_oracle.MSelOracleVisitor.base_update">[docs]</a> def base_update(self, clear_counter=False):
"""
if the best model should be updated
"""
self.trainer.model.save("epoch")
flag = False
if self.tr_obs.metric_val is None:
return super().update(clear_counter)
metric = self.tr_obs.metric_te[self.tr_obs.str_metric4msel]
if self.observer4msel.metric_val is None:
return super().base_update(clear_counter)
metric = self.observer4msel.metric_te[self.observer4msel.str_metric4msel]
if metric &gt; self.best_oracle_acc:
self.best_oracle_acc = metric
if self.msel is not None:
Expand All @@ -358,23 +358,23 @@ <h1 id="modules-domainlab-algos-msels-c-msel-oracle--page-root">Source code for
logger.info("new oracle model saved")
flag = True
if self.msel is not None:
return self.msel.update(clear_counter)
return self.msel.base_update(clear_counter)
return flag</div>

<div class="viewcode-block" id="MSelOracleVisitor.if_stop"><a class="viewcode-back" href="../../../../domainlab.algos.msels.html#domainlab.algos.msels.c_msel_oracle.MSelOracleVisitor.if_stop">[docs]</a> def if_stop(self):
<div class="viewcode-block" id="MSelOracleVisitor.early_stop"><a class="viewcode-back" href="../../../../domainlab.algos.msels.html#domainlab.algos.msels.c_msel_oracle.MSelOracleVisitor.early_stop">[docs]</a> def early_stop(self):
"""
if should early stop
oracle model selection does not intervene how models get selected
by the innermost model selection
"""
if self.msel is not None:
return self.msel.if_stop()
return self.msel.early_stop()
return False</div>

<div class="viewcode-block" id="MSelOracleVisitor.accept"><a class="viewcode-back" href="../../../../domainlab.algos.msels.html#domainlab.algos.msels.c_msel_oracle.MSelOracleVisitor.accept">[docs]</a> def accept(self, trainer, tr_obs):
<div class="viewcode-block" id="MSelOracleVisitor.accept"><a class="viewcode-back" href="../../../../domainlab.algos.msels.html#domainlab.algos.msels.c_msel_oracle.MSelOracleVisitor.accept">[docs]</a> def accept(self, trainer, observer4msel):
if self.msel is not None:
self.msel.accept(trainer, tr_obs)
super().accept(trainer, tr_obs)</div></div>
self.msel.accept(trainer, observer4msel)
super().accept(trainer, observer4msel)</div></div>
</pre></div>

</article>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@

<h1 id="modules-domainlab-algos-msels-c-msel-tr-loss--page-root">Source code for domainlab.algos.msels.c_msel_tr_loss</h1><div class="highlight"><pre>
<span></span>"""
Model Selection should be decoupled from
AMSel.accept ---&gt; Trainer
"""
import math

Expand All @@ -322,8 +322,8 @@ <h1 id="modules-domainlab-algos-msels-c-msel-tr-loss--page-root">Source code for
2. Visitor pattern to trainer
"""

def __init__(self, max_es):
super().__init__()
def __init__(self, max_es, val_threshold = None):
super().__init__(val_threshold)
# NOTE: super() must come first otherwise it will overwrite existing
# values!
self.reset()
Expand All @@ -337,7 +337,7 @@ <h1 id="modules-domainlab-algos-msels-c-msel-tr-loss--page-root">Source code for
def max_es(self):
return self._max_es

<div class="viewcode-block" id="MSelTrLoss.update"><a class="viewcode-back" href="../../../../domainlab.algos.msels.html#domainlab.algos.msels.c_msel_tr_loss.MSelTrLoss.update">[docs]</a> def update(self, clear_counter=False):
<div class="viewcode-block" id="MSelTrLoss.base_update"><a class="viewcode-back" href="../../../../domainlab.algos.msels.html#domainlab.algos.msels.c_msel_tr_loss.MSelTrLoss.base_update">[docs]</a> def base_update(self, clear_counter=False):
"""
if the best model should be updated
"""
Expand All @@ -359,7 +359,7 @@ <h1 id="modules-domainlab-algos-msels-c-msel-tr-loss--page-root">Source code for
self.es_c = 0
return flag</div>

<div class="viewcode-block" id="MSelTrLoss.if_stop"><a class="viewcode-back" href="../../../../domainlab.algos.msels.html#domainlab.algos.msels.c_msel_tr_loss.MSelTrLoss.if_stop">[docs]</a> def if_stop(self):
<div class="viewcode-block" id="MSelTrLoss.early_stop"><a class="viewcode-back" href="../../../../domainlab.algos.msels.html#domainlab.algos.msels.c_msel_tr_loss.MSelTrLoss.early_stop">[docs]</a> def early_stop(self):
"""
if should early stop
"""
Expand Down
22 changes: 11 additions & 11 deletions docs/build/html/_modules/domainlab/algos/msels/c_msel_val.html
Original file line number Diff line number Diff line change
Expand Up @@ -320,8 +320,8 @@ <h1 id="modules-domainlab-algos-msels-c-msel-val--page-root">Source code for dom
2. Visitor pattern to trainer
"""

def __init__(self, max_es):
super().__init__(max_es) # construct self.tr_obs (observer)
def __init__(self, max_es, val_threshold = None):
super().__init__(max_es, val_threshold) # construct self.observer4msel (observer)
self.reset()

<div class="viewcode-block" id="MSelValPerf.reset"><a class="viewcode-back" href="../../../../domainlab.algos.msels.html#domainlab.algos.msels.c_msel_val.MSelValPerf.reset">[docs]</a> def reset(self):
Expand All @@ -348,33 +348,33 @@ <h1 id="modules-domainlab-algos-msels-c-msel-val--page-root">Source code for dom
"""
return self._best_te_metric

<div class="viewcode-block" id="MSelValPerf.update"><a class="viewcode-back" href="../../../../domainlab.algos.msels.html#domainlab.algos.msels.c_msel_val.MSelValPerf.update">[docs]</a> def update(self, clear_counter=False):
<div class="viewcode-block" id="MSelValPerf.base_update"><a class="viewcode-back" href="../../../../domainlab.algos.msels.html#domainlab.algos.msels.c_msel_val.MSelValPerf.base_update">[docs]</a> def base_update(self, clear_counter=False):
"""
if the best model should be updated
"""
flag = True
if self.tr_obs.metric_val is None:
return super().update(clear_counter)
metric = self.tr_obs.metric_val[self.tr_obs.str_metric4msel]
if self.tr_obs.metric_te is not None:
metric_te_current = self.tr_obs.metric_te[self.tr_obs.str_metric4msel]
if self.observer4msel.metric_val is None:
return super().base_update(clear_counter)
metric = self.observer4msel.metric_val[self.observer4msel.str_metric4msel]
if self.observer4msel.metric_te is not None:
metric_te_current = self.observer4msel.metric_te[self.observer4msel.str_metric4msel]
self._best_te_metric = max(self._best_te_metric, metric_te_current)

if metric &gt; self._best_val_acc: # update hat{model}
# different from loss, accuracy should be improved:
# the bigger the better
self._best_val_acc = metric
self.es_c = 0 # restore counter
if self.tr_obs.metric_te is not None:
metric_te_current = self.tr_obs.metric_te[self.tr_obs.str_metric4msel]
if self.observer4msel.metric_te is not None:
metric_te_current = self.observer4msel.metric_te[self.observer4msel.str_metric4msel]
self._sel_model_te_acc = metric_te_current

else:
self.es_c += 1
logger = Logger.get_logger()
logger.info(f"early stop counter: {self.es_c}")
logger.info(
f"val acc:{self.tr_obs.metric_val['acc']}, "
f"val acc:{self.observer4msel.metric_val['acc']}, "
+ f"best validation acc: {self.best_val_acc}, "
+ f"corresponding to test acc: \
{self.sel_model_te_acc} / {self.best_te_metric}"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -362,11 +362,12 @@ <h1 id="modules-domainlab-algos-observers-b-obvisitor--page-root">Source code fo
self.loader_te, self.device
)
self.metric_te = metric_te
if self.model_sel.update():
if self.model_sel.update(epoch):
logger.info("better model found")
self.host_trainer.model.save()
logger.info("persisted")
flag_stop = self.model_sel.if_stop()
acc = self.metric_te.get("acc")
flag_stop = self.model_sel.if_stop(acc)
flag_enough = epoch &gt;= self.host_trainer.aconf.epos_min
return flag_stop &amp; flag_enough</div>

Expand Down Expand Up @@ -411,8 +412,10 @@ <h1 id="modules-domainlab-algos-observers-b-obvisitor--page-root">Source code fo
metric_te.update({"acc_oracle": -1})
if hasattr(self, "model_sel"):
metric_te.update({"acc_val": self.model_sel.best_val_acc})
metric_te.update({"model_selection_epoch": self.model_sel.model_selection_epoch})
else:
metric_te.update({"acc_val": -1})
metric_te.update({"model_selection_epoch": -1})
self.dump_prediction(model_ld, metric_te)
# save metric to one line in csv result file
self.host_trainer.model.visitor(metric_te)</div>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,7 @@ <h1 id="modules-domainlab-algos-trainers-a-trainer--page-root">Source code for d
self.device = None
self.aconf = None
#
self.dict_loader_tr = None
self.loader_tr = None
self.loader_te = None
self.num_batches = None
Expand Down Expand Up @@ -443,6 +444,7 @@ <h1 id="modules-domainlab-algos-trainers-a-trainer--page-root">Source code for d
self.device = device
self.aconf = aconf
#
self.dict_loader_tr = task.dict_loader_tr
self.loader_tr = task.loader_tr
self.loader_te = task.loader_te

Expand Down
Loading

0 comments on commit 0fca21d

Please sign in to comment.