From bc9de7f1e65142cc04c8fbbdd6a626431c8fdf25 Mon Sep 17 00:00:00 2001 From: smilesun Date: Tue, 7 Nov 2023 15:50:51 +0100 Subject: [PATCH] towards fix issue #507 --- domainlab/algos/trainers/train_fbopt_b.py | 10 ++++++---- domainlab/models/model_dann.py | 2 +- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/domainlab/algos/trainers/train_fbopt_b.py b/domainlab/algos/trainers/train_fbopt_b.py index a478d0880..ac1290703 100644 --- a/domainlab/algos/trainers/train_fbopt_b.py +++ b/domainlab/algos/trainers/train_fbopt_b.py @@ -86,10 +86,12 @@ def before_tr(self): self.flag_setpoint_updated = False if self.aconf.force_feedforward: self.set_scheduler(scheduler=HyperSchedulerWarmup) + self.flag_update_hyper_per_epoch = True + self.hyper_scheduler.set_steps(total_steps=self.aconf.warmup) else: self.set_scheduler(scheduler=HyperSchedulerFeedback) - self.set_model_with_mu() # very small value + self.set_model_with_mu(0) # very small value if self.aconf.tr_with_init_mu: self.tr_with_init_mu() @@ -113,12 +115,12 @@ def tr_with_init_mu(self): """ super().tr_epoch(-1) - def set_model_with_mu(self): + def set_model_with_mu(self, epoch): """ set model multipliers """ # self.model.hyper_update(epoch=None, fun_scheduler=HyperSetter(self.hyper_scheduler.mmu)) - self.model.hyper_update(epoch=None, fun_scheduler=self.hyper_scheduler) + self.model.hyper_update(epoch=epoch, fun_scheduler=self.hyper_scheduler) def tr_epoch(self, epoch, flag_info=False): """ @@ -130,7 +132,7 @@ def tr_epoch(self, epoch, flag_info=False): self.epo_loss_tr, self.list_str_multiplier_na, miter=epoch) - self.set_model_with_mu() + self.set_model_with_mu(epoch) if hasattr(self.model, "dict_multiplier"): logger = Logger.get_logger() logger.info(f"current multiplier: {self.model.dict_multiplier}") diff --git a/domainlab/models/model_dann.py b/domainlab/models/model_dann.py index 292acb8b5..574f9bacc 100644 --- a/domainlab/models/model_dann.py +++ b/domainlab/models/model_dann.py @@ -57,7 +57,7 @@ def __init__(self, list_str_y, list_str_d, self.net_encoder = net_encoder self.net_classifier = net_classifier self.net_discriminator = net_discriminator - + @property def list_str_multiplier_na(self): return ["alpha"]