Skip to content

Commit

Permalink
Update c_msel_val.py
Browse files Browse the repository at this point in the history
  • Loading branch information
smilesun authored Oct 8, 2023
1 parent 188b5de commit 5775c10
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions domainlab/algos/msels/c_msel_val.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,19 @@ def __init__(self, max_es):
self.best_te_metric = 0.0
super().__init__(max_es) # construct self.tr_obs (observer)

def update(self):
def update(self, clear_counter=False):
"""
if the best model should be updated
"""
flag = True
if self.tr_obs.metric_val is None or self.tr_obs.str_msel == "loss_tr":
return super().update()
return super().update(clear_counter)
metric = self.tr_obs.metric_val[self.tr_obs.str_metric4msel]
if self.tr_obs.metric_te is not None:
metric_te_current = self.tr_obs.metric_te[self.tr_obs.str_metric4msel]
self.best_te_metric = max(self.best_te_metric, metric_te_current)

if metric > self.best_val_acc: # observer
if metric > self.best_val_acc: # update hat{model}
# different from loss, accuracy should be improved: the bigger the better
self.best_val_acc = metric
self.es_c = 0 # restore counter
Expand All @@ -45,5 +45,7 @@ def update(self):
f"corresponding to test acc: \
{self.sel_model_te_acc} / {self.best_te_metric}")
flag = False # do not update best model

if clear_counter:
logger.info("clearing counter")
self.es_c = 0
return flag

0 comments on commit 5775c10

Please sign in to comment.