Skip to content

Commit

Permalink
acc read from model, save oracle
Browse files Browse the repository at this point in the history
  • Loading branch information
smilesun committed Oct 2, 2023
1 parent 71be477 commit 2d5973c
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 10 deletions.
21 changes: 14 additions & 7 deletions domainlab/algos/msels/c_msel_oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ class MSelOracleVisitor(AMSel):
save best out-of-domain test acc model, but do not affect
how the final model is selected
"""
def __init__(self, msel):
def __init__(self, msel=None):
"""
Decorator pattern
"""
Expand All @@ -23,13 +23,20 @@ def update(self):
if the best model should be updated
"""
self.tr_obs.exp.visitor.save(self.trainer.model, "epoch")
if self.tr_obs.metric_te["acc"] > self.best_oracle_acc:
self.best_oracle_acc = self.tr_obs.metric_te["acc"]
# FIXME: only works for classification
self.tr_obs.exp.visitor.save(self.trainer.model, "oracle")
flag = False
metric = self.tr_obs.metric_te[self.tr_obs.str_msel]
if metric > self.best_oracle_acc:
self.best_oracle_acc = metric
if self.msel is not None:
self.tr_obs.exp.visitor.save(self.trainer.model, "oracle")
else:
self.tr_obs.exp.visitor.save(self.trainer.model)
logger = Logger.get_logger()
logger.info("oracle model saved")
return self.msel.update()
logger.info("new oracle model saved")
flag = True
if self.msel is not None:
return self.msel.update()
return flag

def if_stop(self):
"""
Expand Down
6 changes: 3 additions & 3 deletions domainlab/algos/msels/c_msel_val.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@ def update(self):
flag = True
if self.tr_obs.metric_val is None or self.tr_obs.str_msel == "loss_tr":
return super().update()
if self.tr_obs.metric_val["acc"] > self.best_val_acc: # observer
metric = self.tr_obs.metric_val[self.tr_obs.str_metric4msel]
if metric > self.best_val_acc: # observer
# different from loss, accuracy should be improved: the bigger the better
self.best_val_acc = self.tr_obs.metric_val["acc"]
# FIXME: only works for classification
self.best_val_acc = metric
self.es_c = 0 # restore counter

else:
Expand Down
6 changes: 6 additions & 0 deletions domainlab/algos/trainers/a_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,12 @@ def __init__(self, successor_node=None):
self.loader_tr_source_target = None
self.flag_initialized = False

@property
def str_metric4msel(self):
"""
metric for model selection
"""
return self.model.metric4msel

def init_business(self, model, task, observer, device, aconf, flag_accept=True):
"""
Expand Down
7 changes: 7 additions & 0 deletions domainlab/models/a_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,13 @@ class AModel(nn.Module, metaclass=abc.ABCMeta):
"""
operations that all models (classification, segmentation, seq2seq)
"""
@property
def metric4msel(self):
"""
metric for model selection
"""
raise NotImplementedError

@property
def multiplier4task_loss(self):
"""
Expand Down
4 changes: 4 additions & 0 deletions domainlab/models/a_model_classif.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ class AModelClassif(AModel, metaclass=abc.ABCMeta):
"""
match_feat_fun_na = "cal_logit_y"

@property
def metric4msel(self):
return "acc"

def create_perf_obj(self, task):
"""
for classification, dimension of target can be quieried from task
Expand Down

0 comments on commit 2d5973c

Please sign in to comment.