From 2d5973cd49d24b45ec42f8aa387d54e0d75799c7 Mon Sep 17 00:00:00 2001 From: smilesun Date: Mon, 2 Oct 2023 15:32:38 +0200 Subject: [PATCH] acc read from model, save oracle --- domainlab/algos/msels/c_msel_oracle.py | 21 ++++++++++++++------- domainlab/algos/msels/c_msel_val.py | 6 +++--- domainlab/algos/trainers/a_trainer.py | 6 ++++++ domainlab/models/a_model.py | 7 +++++++ domainlab/models/a_model_classif.py | 4 ++++ 5 files changed, 34 insertions(+), 10 deletions(-) diff --git a/domainlab/algos/msels/c_msel_oracle.py b/domainlab/algos/msels/c_msel_oracle.py index 421b7922b..eefc1d130 100644 --- a/domainlab/algos/msels/c_msel_oracle.py +++ b/domainlab/algos/msels/c_msel_oracle.py @@ -10,7 +10,7 @@ class MSelOracleVisitor(AMSel): save best out-of-domain test acc model, but do not affect how the final model is selected """ - def __init__(self, msel): + def __init__(self, msel=None): """ Decorator pattern """ @@ -23,13 +23,20 @@ def update(self): if the best model should be updated """ self.tr_obs.exp.visitor.save(self.trainer.model, "epoch") - if self.tr_obs.metric_te["acc"] > self.best_oracle_acc: - self.best_oracle_acc = self.tr_obs.metric_te["acc"] - # FIXME: only works for classification - self.tr_obs.exp.visitor.save(self.trainer.model, "oracle") + flag = False + metric = self.tr_obs.metric_te[self.tr_obs.str_msel] + if metric > self.best_oracle_acc: + self.best_oracle_acc = metric + if self.msel is not None: + self.tr_obs.exp.visitor.save(self.trainer.model, "oracle") + else: + self.tr_obs.exp.visitor.save(self.trainer.model) logger = Logger.get_logger() - logger.info("oracle model saved") - return self.msel.update() + logger.info("new oracle model saved") + flag = True + if self.msel is not None: + return self.msel.update() + return flag def if_stop(self): """ diff --git a/domainlab/algos/msels/c_msel_val.py b/domainlab/algos/msels/c_msel_val.py index 511bc3362..c41975252 100644 --- a/domainlab/algos/msels/c_msel_val.py +++ b/domainlab/algos/msels/c_msel_val.py @@ -21,10 +21,10 @@ def update(self): flag = True if self.tr_obs.metric_val is None or self.tr_obs.str_msel == "loss_tr": return super().update() - if self.tr_obs.metric_val["acc"] > self.best_val_acc: # observer + metric = self.tr_obs.metric_val[self.tr_obs.str_metric4msel] + if metric > self.best_val_acc: # observer # different from loss, accuracy should be improved: the bigger the better - self.best_val_acc = self.tr_obs.metric_val["acc"] - # FIXME: only works for classification + self.best_val_acc = metric self.es_c = 0 # restore counter else: diff --git a/domainlab/algos/trainers/a_trainer.py b/domainlab/algos/trainers/a_trainer.py index 4dbe745cc..b95b23af7 100644 --- a/domainlab/algos/trainers/a_trainer.py +++ b/domainlab/algos/trainers/a_trainer.py @@ -52,6 +52,12 @@ def __init__(self, successor_node=None): self.loader_tr_source_target = None self.flag_initialized = False + @property + def str_metric4msel(self): + """ + metric for model selection + """ + return self.model.metric4msel def init_business(self, model, task, observer, device, aconf, flag_accept=True): """ diff --git a/domainlab/models/a_model.py b/domainlab/models/a_model.py index 7625d246a..ec8ad2f96 100644 --- a/domainlab/models/a_model.py +++ b/domainlab/models/a_model.py @@ -11,6 +11,13 @@ class AModel(nn.Module, metaclass=abc.ABCMeta): """ operations that all models (classification, segmentation, seq2seq) """ + @property + def metric4msel(self): + """ + metric for model selection + """ + raise NotImplementedError + @property def multiplier4task_loss(self): """ diff --git a/domainlab/models/a_model_classif.py b/domainlab/models/a_model_classif.py index 2e3d20b13..74bc33268 100644 --- a/domainlab/models/a_model_classif.py +++ b/domainlab/models/a_model_classif.py @@ -24,6 +24,10 @@ class AModelClassif(AModel, metaclass=abc.ABCMeta): """ match_feat_fun_na = "cal_logit_y" + @property + def metric4msel(self): + return "acc" + def create_perf_obj(self, task): """ for classification, dimension of target can be quieried from task