Commit

Merge branch 'fbopt' into fbopt_fast_trainer
smilesun authored Oct 5, 2023
2 parents 8639fde + 59584b9 commit ac3cbbf
Showing 3 changed files with 17 additions and 41 deletions.
8 changes: 4 additions & 4 deletions domainlab/models/model_diva.py
@@ -61,7 +61,7 @@ def __init__(self, chain_node_builder,
                      zd_dim, zy_dim, zx_dim,
                      list_str_y, list_d_tr,
                      gamma_d, gamma_y,
-                     beta_d, beta_x, beta_y, multiplier_recon=1.0):
+                     beta_d, beta_x, beta_y, mu_recon=1.0):
             """
             gamma: classification loss coefficient
             """
@@ -90,6 +90,7 @@ def hyper_update(self, epoch, fun_scheduler):
             self.beta_y = dict_rst["beta_y"]
             self.beta_x = dict_rst["beta_x"]
             self.gamma_d = dict_rst["gamma_d"]
+            self.mu_recon = dict_rst["mu_recon"]

         def hyper_init(self, functor_scheduler, trainer=None):
             """
@@ -99,7 +100,7 @@ def hyper_init(self, functor_scheduler, trainer=None):
             """
             return functor_scheduler(
                 trainer=trainer,
-                mu_recon=self.multiplier_recon,
+                mu_recon=self.mu_recon,
                 beta_d=self.beta_d,
                 beta_y=self.beta_y,
                 beta_x=self.beta_x,
@@ -139,7 +140,6 @@ def cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others=None):

             _, d_target = tensor_d.max(dim=1)
             lc_d = F.cross_entropy(logit_d, d_target, reduction="none")
-
             return [loss_recon_x, zd_p_minus_zd_q, zx_p_minus_zx_q, zy_p_minus_zy_q, lc_d], \
-                [self.multiplier_recon, -self.beta_d, -self.beta_x, -self.beta_y, -self.gamma_d]
+                [self.mu_recon, -self.beta_d, -self.beta_x, -self.beta_y, -self.gamma_d]
     return ModelDIVA
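
Both hunks above implement a single contract: hyper_init hands the model's current multipliers to a scheduler functor under fixed keyword names, and hyper_update later reads the rescheduled values back out of a dict under the same names. Renaming multiplier_recon to mu_recon makes the attribute, the keyword argument, and the dict key agree, which is what lets hyper_update restore self.mu_recon generically. A minimal sketch of a functor satisfying that contract; the class name and the warm-up rule are illustrative, not DomainLab's implementation, and it assumes hyper_update obtains dict_rst by calling the functor with the epoch (that call sits in the collapsed part of the diff):

class ToyScheduler:
    """Illustrative scheduler functor; not DomainLab's actual scheduler."""

    def __init__(self, trainer=None, **kwargs):
        # kwargs arrive from hyper_init, e.g.
        # {"mu_recon": 1.0, "beta_d": ..., "beta_y": ..., "beta_x": ...}
        self.trainer = trainer
        self.init_vals = dict(kwargs)

    def __call__(self, epoch):
        # toy linear warm-up over the first 10 epochs (illustrative rule)
        ramp = min(1.0, (epoch + 1) / 10.0)
        return {key: val * ramp for key, val in self.init_vals.items()}

# scheduler = model.hyper_init(ToyScheduler, trainer=trainer)
# model.hyper_update(epoch, scheduler)  # ends up setting model.mu_recon etc.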
11 changes: 7 additions & 4 deletions domainlab/models/model_hduva.py
@@ -68,6 +68,7 @@ def hyper_update(self, epoch, fun_scheduler):
             self.beta_y = dict_rst["beta_y"]
             self.beta_x = dict_rst["beta_x"]
             self.beta_t = dict_rst["beta_t"]
+            self.mu_recon = dict_rst["mu_recon"]

         def hyper_init(self, functor_scheduler, trainer=None):
             """hyper_init.
@@ -78,8 +79,10 @@ def hyper_init(self, functor_scheduler, trainer=None):
             # constructor signature is def __init__(self, **kwargs):
             return functor_scheduler(
                 trainer=trainer,
-                mu_recon=self.multiplier_recon,
-                beta_d=self.beta_d, beta_y=self.beta_y, beta_x=self.beta_x,
+                mu_recon=self.mu_recon,
+                beta_d=self.beta_d,
+                beta_y=self.beta_y,
+                beta_x=self.beta_x,
                 beta_t=self.beta_t)

         @store_args
@@ -92,7 +95,7 @@ def __init__(self, chain_node_builder,
                      device,
                      zx_dim=0,
                      topic_dim=3,
-                     multiplier_recon=1.0):
+                     mu_recon=1.0):
             """
             """
             super().__init__(chain_node_builder,
@@ -165,7 +168,7 @@ def cal_reg_loss(self, tensor_x, tensor_y, tensor_d=None, others=None):
             z_concat = self.decoder.concat_ytdx(zy_q, topic_q, zd_q, zx_q)
             loss_recon_x, _, _ = self.decoder(z_concat, tensor_x)
             return [loss_recon_x, zx_p_minus_q, zy_p_minus_zy_q, zd_p_minus_q, topic_p_minus_q], \
-                [self.multiplier_recon, -self.beta_x, -self.beta_y, -self.beta_d, -self.beta_t]
+                [self.mu_recon, -self.beta_x, -self.beta_y, -self.beta_d, -self.beta_t]

         def extract_semantic_features(self, tensor_x):
             """
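
In both models, cal_reg_loss returns two parallel lists: per-sample loss components, and scalar multipliers in which the reconstruction term carries mu_recon while the beta/gamma terms enter negated. A trainer can reduce the pair with a weighted sum; a minimal sketch, with a hypothetical helper name:

import torch

def weighted_reg_loss(list_loss, list_multiplier):
    """Combine the two lists returned by cal_reg_loss.

    list_loss: tensors of shape (batch,), e.g.
        [loss_recon_x, zd_p_minus_zd_q, zx_p_minus_zx_q, zy_p_minus_zy_q, lc_d]
    list_multiplier: matching scalars, e.g.
        [mu_recon, -beta_d, -beta_x, -beta_y, -gamma_d]
    """
    total = torch.zeros_like(list_loss[0])
    for loss, mult in zip(list_loss, list_multiplier):
        total = total + mult * loss
    return total  # per-sample; .mean() or .sum() gives the batch objective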
39 changes: 6 additions & 33 deletions examples/benchmark/benchmark_fbopt_mnist_diva.yaml
@@ -7,7 +7,6 @@ startseed: 0
 endseed: 2

 test_domains:
-  - 3
   - 0


@@ -17,10 +16,13 @@ domainlab_args:
   dmem: False
   lr: 0.001
   epos: 500
-  es: 100
+  es: 5
   bs: 64
   zx_dim: 0
+  zy_dim: 32
+  zd_dim: 32
+  gamma_d: 1.0
+  gamma_y: 1.0
   nname: conv_bn_pool_2
   nname_dom: conv_bn_pool_2
   nname_topic_distrib_img2topic: conv_bn_pool_2
@@ -38,45 +40,20 @@ Shared params:
     min: 0.9
     max: 0.99
     num: 3
-    step: 0.05
     distribution: uniform

   k_i_gain:
     min: 0.0001
     max: 0.01
     num: 2
-    step: 0.0001
-    distribution: uniform
+    distribution: loguniform

   init_mu4beta:
-    min: 0.01
+    min: 0.0001
     max: 1.0
     num: 5
     distribution: uniform

-  gamma_y:
-    min: 1
-    max: 1e6
-    num: 3
-    step: 100
-    distribution: loguniform
-
-  zy_dim:
-    min: 32
-    max: 96
-    num: 2
-    step: 32
-    distribution: uniform
-    datatype: int
-
-  zd_dim:
-    min: 32
-    max: 96
-    num: 2
-    step: 32
-    distribution: uniform
-    datatype: int


 # Test fbopt with different hyperparameter configurations
@@ -89,10 +66,6 @@ diva_fbopt:
     - ini_setpoint_ratio
     - k_i_gain
     - init_mu4beta
-    - gamma_y
-    - zx_dim
-    - zy_dim
-    - zd_dim

 erm:
   aname: deepall
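
The k_i_gain search space above switches from a uniform grid with a fixed step to a loguniform distribution, which suits a range spanning two orders of magnitude (0.0001 to 0.01): draws are uniform in log space, so each decade gets equal weight. A sketch of the difference, assuming min/max/num keep their plain meaning here (this is not DomainLab's actual sampler):

import numpy as np

rng = np.random.default_rng(0)

def sample_uniform(lo, hi, num):
    """distribution: uniform -- equal density on the linear scale."""
    return rng.uniform(lo, hi, size=num)

def sample_loguniform(lo, hi, num):
    """distribution: loguniform -- uniform in log space."""
    return np.exp(rng.uniform(np.log(lo), np.log(hi), size=num))

print(sample_uniform(0.0001, 0.01, num=2))     # most draws land in the top decade
print(sample_loguniform(0.0001, 0.01, num=2))  # both decades equally likely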
