Skip to content

Commit

Permalink
towards fix issue #507
Browse files Browse the repository at this point in the history
  • Loading branch information
smilesun committed Nov 7, 2023
1 parent cdf0c56 commit bc9de7f
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 5 deletions.
10 changes: 6 additions & 4 deletions domainlab/algos/trainers/train_fbopt_b.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,12 @@ def before_tr(self):
self.flag_setpoint_updated = False
if self.aconf.force_feedforward:
self.set_scheduler(scheduler=HyperSchedulerWarmup)
self.flag_update_hyper_per_epoch = True
self.hyper_scheduler.set_steps(total_steps=self.aconf.warmup)
else:
self.set_scheduler(scheduler=HyperSchedulerFeedback)

self.set_model_with_mu() # very small value
self.set_model_with_mu(0) # very small value
if self.aconf.tr_with_init_mu:
self.tr_with_init_mu()

Expand All @@ -113,12 +115,12 @@ def tr_with_init_mu(self):
"""
super().tr_epoch(-1)

def set_model_with_mu(self):
def set_model_with_mu(self, epoch):
"""
set model multipliers
"""
# self.model.hyper_update(epoch=None, fun_scheduler=HyperSetter(self.hyper_scheduler.mmu))
self.model.hyper_update(epoch=None, fun_scheduler=self.hyper_scheduler)
self.model.hyper_update(epoch=epoch, fun_scheduler=self.hyper_scheduler)

def tr_epoch(self, epoch, flag_info=False):
"""
Expand All @@ -130,7 +132,7 @@ def tr_epoch(self, epoch, flag_info=False):
self.epo_loss_tr,
self.list_str_multiplier_na,
miter=epoch)
self.set_model_with_mu()
self.set_model_with_mu(epoch)
if hasattr(self.model, "dict_multiplier"):
logger = Logger.get_logger()
logger.info(f"current multiplier: {self.model.dict_multiplier}")
Expand Down
2 changes: 1 addition & 1 deletion domainlab/models/model_dann.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def __init__(self, list_str_y, list_str_d,
self.net_encoder = net_encoder
self.net_classifier = net_classifier
self.net_discriminator = net_discriminator

@property
def list_str_multiplier_na(self):
return ["alpha"]
Expand Down

0 comments on commit bc9de7f

Please sign in to comment.