From 5775c10624a70f822dabf8649560b8271f4e046e Mon Sep 17 00:00:00 2001 From: Xudong Sun Date: Sun, 8 Oct 2023 12:42:35 +0200 Subject: [PATCH] Update c_msel_val.py --- domainlab/algos/msels/c_msel_val.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/domainlab/algos/msels/c_msel_val.py b/domainlab/algos/msels/c_msel_val.py index 939cc47a0..01497f0c4 100644 --- a/domainlab/algos/msels/c_msel_val.py +++ b/domainlab/algos/msels/c_msel_val.py @@ -16,19 +16,19 @@ def __init__(self, max_es): self.best_te_metric = 0.0 super().__init__(max_es) # construct self.tr_obs (observer) - def update(self): + def update(self, clear_counter=False): """ if the best model should be updated """ flag = True if self.tr_obs.metric_val is None or self.tr_obs.str_msel == "loss_tr": - return super().update() + return super().update(clear_counter) metric = self.tr_obs.metric_val[self.tr_obs.str_metric4msel] if self.tr_obs.metric_te is not None: metric_te_current = self.tr_obs.metric_te[self.tr_obs.str_metric4msel] self.best_te_metric = max(self.best_te_metric, metric_te_current) - if metric > self.best_val_acc: # observer + if metric > self.best_val_acc: # update hat{model} # different from loss, accuracy should be improved: the bigger the better self.best_val_acc = metric self.es_c = 0 # restore counter @@ -45,5 +45,7 @@ def update(self): f"corresponding to test acc: \ {self.sel_model_te_acc} / {self.best_te_metric}") flag = False # do not update best model - + if clear_counter: + logger.info("clearing counter") + self.es_c = 0 return flag