From bd14244c7e123e02834641d21ea4db94ca60945b Mon Sep 17 00:00:00 2001 From: smilesun Date: Thu, 5 Oct 2023 14:09:57 +0200 Subject: [PATCH 1/3] add mu_recon to adapt --- domainlab/models/model_diva.py | 1 + .../benchmark/benchmark_fbopt_mnist_diva.yaml | 36 +++---------------- 2 files changed, 6 insertions(+), 31 deletions(-) diff --git a/domainlab/models/model_diva.py b/domainlab/models/model_diva.py index d250b75ab..ea8c083d8 100644 --- a/domainlab/models/model_diva.py +++ b/domainlab/models/model_diva.py @@ -90,6 +90,7 @@ def hyper_update(self, epoch, fun_scheduler): self.beta_y = dict_rst["beta_y"] self.beta_x = dict_rst["beta_x"] self.gamma_d = dict_rst["gamma_d"] + self.mu_recon = dict_rst["mu_recon"] def hyper_init(self, functor_scheduler, trainer=None): """ diff --git a/examples/benchmark/benchmark_fbopt_mnist_diva.yaml b/examples/benchmark/benchmark_fbopt_mnist_diva.yaml index a01245b50..c532050a8 100644 --- a/examples/benchmark/benchmark_fbopt_mnist_diva.yaml +++ b/examples/benchmark/benchmark_fbopt_mnist_diva.yaml @@ -20,7 +20,10 @@ domainlab_args: es: 100 bs: 64 zx_dim: 0 + zy_dim: 32 + zd_dim: 32 gamma_d: 1.0 + gamma_y: 1.0 nname: conv_bn_pool_2 nname_dom: conv_bn_pool_2 nname_topic_distrib_img2topic: conv_bn_pool_2 @@ -38,45 +41,20 @@ Shared params: min: 0.9 max: 0.99 num: 3 - step: 0.05 distribution: uniform k_i_gain: min: 0.0001 max: 0.01 num: 2 - step: 0.0001 - distribution: uniform + distribution: loguniform init_mu4beta: - min: 0.01 + min: 0.0001 max: 1.0 num: 5 - distribution: uniform - - gamma_y: - min: 1 - max: 1e6 - num: 3 - step: 100 distribution: loguniform - zy_dim: - min: 32 - max: 96 - num: 2 - step: 32 - distribution: uniform - datatype: int - - zd_dim: - min: 32 - max: 96 - num: 2 - step: 32 - distribution: uniform - datatype: int - # Test fbopt with different hyperparameter configurations @@ -89,10 +67,6 @@ diva_fbopt: - ini_setpoint_ratio - k_i_gain - init_mu4beta - - gamma_y - - zx_dim - - zy_dim - - zd_dim erm: aname: deepall From 7711dbbb266e5dca9e0d13e169b14eda2a30acc5 Mon Sep 17 00:00:00 2001 From: smilesun Date: Thu, 5 Oct 2023 14:19:01 +0200 Subject: [PATCH 2/3] recon into dynamic tune --- domainlab/models/model_diva.py | 7 +++---- domainlab/models/model_hduva.py | 11 +++++++---- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/domainlab/models/model_diva.py b/domainlab/models/model_diva.py index ea8c083d8..752d5acf2 100644 --- a/domainlab/models/model_diva.py +++ b/domainlab/models/model_diva.py @@ -61,7 +61,7 @@ def __init__(self, chain_node_builder, zd_dim, zy_dim, zx_dim, list_str_y, list_d_tr, gamma_d, gamma_y, - beta_d, beta_x, beta_y, multiplier_recon=1.0): + beta_d, beta_x, beta_y, mu_recon=1.0): """ gamma: classification loss coefficient """ @@ -100,7 +100,7 @@ def hyper_init(self, functor_scheduler, trainer=None): """ return functor_scheduler( trainer=trainer, - mu_recon=self.multiplier_recon, + mu_recon=self.mu_recon, beta_d=self.beta_d, beta_y=self.beta_y, beta_x=self.beta_x, @@ -140,7 +140,6 @@ def cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others=None): _, d_target = tensor_d.max(dim=1) lc_d = F.cross_entropy(logit_d, d_target, reduction="none") - return [loss_recon_x, zd_p_minus_zd_q, zx_p_minus_zx_q, zy_p_minus_zy_q, lc_d], \ - [self.multiplier_recon, -self.beta_d, -self.beta_x, -self.beta_y, -self.gamma_d] + [self.mu_recon, -self.beta_d, -self.beta_x, -self.beta_y, -self.gamma_d] return ModelDIVA diff --git a/domainlab/models/model_hduva.py b/domainlab/models/model_hduva.py index c57c4331b..6d311bc57 100644 --- a/domainlab/models/model_hduva.py +++ b/domainlab/models/model_hduva.py @@ -68,6 +68,7 @@ def hyper_update(self, epoch, fun_scheduler): self.beta_y = dict_rst["beta_y"] self.beta_x = dict_rst["beta_x"] self.beta_t = dict_rst["beta_t"] + self.mu_recon = dict_rst["mu_recon"] def hyper_init(self, functor_scheduler, trainer=None): """hyper_init. @@ -78,8 +79,10 @@ def hyper_init(self, functor_scheduler, trainer=None): # constructor signature is def __init__(self, **kwargs): return functor_scheduler( trainer=trainer, - mu_recon=self.multiplier_recon, - beta_d=self.beta_d, beta_y=self.beta_y, beta_x=self.beta_x, + mu_recon=self.mu_recon, + beta_d=self.beta_d, + beta_y=self.beta_y, + beta_x=self.beta_x, beta_t=self.beta_t) @store_args @@ -92,7 +95,7 @@ def __init__(self, chain_node_builder, device, zx_dim=0, topic_dim=3, - multiplier_recon=1.0): + mu_recon=1.0): """ """ super().__init__(chain_node_builder, @@ -165,7 +168,7 @@ def cal_reg_loss(self, tensor_x, tensor_y, tensor_d=None, others=None): z_concat = self.decoder.concat_ytdx(zy_q, topic_q, zd_q, zx_q) loss_recon_x, _, _ = self.decoder(z_concat, tensor_x) return [loss_recon_x, zx_p_minus_q, zy_p_minus_zy_q, zd_p_minus_q, topic_p_minus_q], \ - [self.multiplier_recon, -self.beta_x, -self.beta_y, -self.beta_d, -self.beta_t] + [self.mu_recon, -self.beta_x, -self.beta_y, -self.beta_d, -self.beta_t] def extract_semantic_features(self, tensor_x): """ From 59584b9d76266224334dea9c0c45258ca1fc05cd Mon Sep 17 00:00:00 2001 From: smilesun Date: Thu, 5 Oct 2023 15:34:42 +0200 Subject: [PATCH 3/3] small early stop diva --- examples/benchmark/benchmark_fbopt_mnist_diva.yaml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/examples/benchmark/benchmark_fbopt_mnist_diva.yaml b/examples/benchmark/benchmark_fbopt_mnist_diva.yaml index c532050a8..8ae3784ac 100644 --- a/examples/benchmark/benchmark_fbopt_mnist_diva.yaml +++ b/examples/benchmark/benchmark_fbopt_mnist_diva.yaml @@ -7,7 +7,6 @@ startseed: 0 endseed: 2 test_domains: - - 3 - 0 @@ -17,7 +16,7 @@ domainlab_args: dmem: False lr: 0.001 epos: 500 - es: 100 + es: 5 bs: 64 zx_dim: 0 zy_dim: 32