Fbopt_depreciated #353

Closed
wants to merge 254 commits
Changes from 29 commits
Commits
254 commits
10fb1fe
.
smilesun Sep 11, 2023
61582d8
.
smilesun Sep 11, 2023
fb57322
.
smilesun Sep 11, 2023
72f7b7d
.
smilesun Sep 11, 2023
69d8858
.
smilesun Sep 11, 2023
6164bd8
cuda out of memory
smilesun Sep 11, 2023
e8c60d5
.
smilesun Sep 11, 2023
eee027b
.
smilesun Sep 11, 2023
273e24f
.
smilesun Sep 11, 2023
d4a80c6
.
smilesun Sep 11, 2023
3509017
.
smilesun Sep 11, 2023
0863481
.
smilesun Sep 11, 2023
6c43cf3
.
smilesun Sep 11, 2023
97eb62a
.
smilesun Sep 11, 2023
901bc81
.
smilesun Sep 11, 2023
ed5945c
.
smilesun Sep 11, 2023
f950ed5
.
smilesun Sep 11, 2023
954b97a
.
smilesun Sep 11, 2023
62cafe9
.
smilesun Sep 11, 2023
f8f7a65
.
smilesun Sep 11, 2023
aaf3aeb
.
smilesun Sep 11, 2023
d8db3fd
.
smilesun Sep 11, 2023
25675c9
.
smilesun Sep 11, 2023
07f24de
.
smilesun Sep 11, 2023
0ce810b
Update test_fbopt.py
smilesun Sep 11, 2023
78bec02
Merge branch 'master' into fbopt
smilesun Sep 11, 2023
867b006
.
smilesun Sep 11, 2023
7cd92d2
add optimizer reset
smilesun Sep 12, 2023
e431e3c
mu iter start
smilesun Sep 12, 2023
3512acd
Merge branch 'fbopt' of github.com:marrlab/DomainLab into fbopt
smilesun Sep 12, 2023
d7d8214
no early stop
smilesun Sep 12, 2023
3461fa9
return false in flag_stop
smilesun Sep 12, 2023
f1d6fcf
.
smilesun Sep 12, 2023
b4eaf69
Merge remote-tracking branch 'marrlab/xd_fix_bub_logger' into fbopt
smilesun Sep 12, 2023
bca6f3f
hyperparameters for fbopt
smilesun Sep 12, 2023
959b9a2
forgot to add args to main
smilesun Sep 12, 2023
8343b1f
multipler with np.power
smilesun Sep 12, 2023
36d2e36
.
smilesun Sep 12, 2023
17bc89e
run schell
smilesun Sep 12, 2023
5a85025
eval loss with torch_no grad
smilesun Sep 12, 2023
a716e71
detailed log
smilesun Sep 13, 2023
ad0be7c
no drop for loader
smilesun Sep 13, 2023
9b921fd
same loss function has big variation
smilesun Sep 13, 2023
a35c50d
print success rate
smilesun Sep 13, 2023
c91e919
reduce repetitive computation
smilesun Sep 13, 2023
8b9b625
.
smilesun Sep 13, 2023
cfaeef6
.
smilesun Sep 13, 2023
e76bb5a
demo file to debug why loss changes
smilesun Sep 13, 2023
56c85ce
Merge branch 'xd_refactor_reg_loss_separate_multipler' into fbopt_reg
smilesun Sep 13, 2023
891dbeb
Merge branch 'fbopt' into fbopt_reg
smilesun Sep 14, 2023
daac36d
code refinement, mu iteration start at 1
smilesun Sep 14, 2023
c03e639
correct reg loss
smilesun Sep 14, 2023
b04aa58
Merge pull request #368 from marrlab/fbopt_reg
smilesun Sep 14, 2023
0dcb3cc
fixed randomness
smilesun Sep 14, 2023
3a1a403
add yaml for fbopt
lisab00 Sep 14, 2023
5eeda3b
Merge branch 'fbopt' into lb_benchm_fbopt
lisab00 Sep 14, 2023
e87df4f
fix mu iteration
smilesun Sep 14, 2023
a428d79
big mu seems to lead to finding of reg-descent operator
smilesun Sep 14, 2023
b59629a
.
smilesun Sep 14, 2023
e742e32
reorganize code
smilesun Sep 14, 2023
2a0c8fb
separate theta and theta bar
smilesun Sep 14, 2023
b0c7e8a
increase beta_mu
lisab00 Sep 14, 2023
3e0123a
Merge branch 'fbopt' into lb_benchm_fbopt
lisab00 Sep 14, 2023
e95e9f0
works
smilesun Sep 14, 2023
4538fff
free descent operator
smilesun Sep 14, 2023
e66faa8
.
smilesun Sep 14, 2023
af0f1e6
Merge remote-tracking branch 'origin/fbopt' into lb_benchm_fbopt
Sep 15, 2023
676c44f
Merge branch 'master' into fbopt
smilesun Sep 15, 2023
4048286
Merge remote-tracking branch 'origin/fbopt' into lb_benchm_fbopt
Sep 15, 2023
b4b008a
Merge branch 'fbopt' into lb_benchm_fbopt
lisab00 Sep 15, 2023
f84047b
Merge branch 'lb_benchm_fbopt' of https://github.com/marrlab/DomainLa…
Sep 15, 2023
71698f4
add datatype
lisab00 Sep 15, 2023
b8d0dfa
Merge branch 'master' into fbopt
smilesun Sep 18, 2023
d8e1358
Merge branch 'fbopt' into lb_benchm_fbopt
smilesun Sep 18, 2023
33f2cd3
Merge branch 'fbopt' into fbopt_safe
smilesun Sep 18, 2023
9f5aa78
merge
smilesun Sep 18, 2023
61bb942
Merge pull request #377 from marrlab/fbopt_safe
smilesun Sep 18, 2023
1b4ead5
.
smilesun Sep 18, 2023
b4c3235
.
smilesun Sep 18, 2023
878f1fd
Merge branch 'fbopt' of github.com:marrlab/DomainLab into fbopt
smilesun Sep 18, 2023
420d94a
.
smilesun Sep 18, 2023
e6701be
.
smilesun Sep 18, 2023
a4d69a4
Merge branch 'fbopt' of github.com:marrlab/DomainLab into fbopt
smilesun Sep 18, 2023
596fae5
.
smilesun Sep 18, 2023
e942d50
Merge branch 'fbopt' of github.com:marrlab/DomainLab into fbopt
smilesun Sep 18, 2023
d6ad31b
.
smilesun Sep 18, 2023
242ecf2
.
smilesun Sep 18, 2023
aa514c7
tensorboard ready
smilesun Sep 19, 2023
a15a721
moving average
ntchen Sep 19, 2023
d61684a
moving average
ntchen Sep 19, 2023
f09a9ce
.
smilesun Sep 19, 2023
7ffdce0
Merge branch 'fbopt_alternate' of github.com:marrlab/DomainLab into f…
smilesun Sep 19, 2023
6ba72c2
moving average runable, loss being supressed
smilesun Sep 19, 2023
2db9cc9
clip
ntchen Sep 19, 2023
b5d6720
clip works for mu
smilesun Sep 19, 2023
e8de85b
fix bug clip and report
smilesun Sep 19, 2023
1ee4680
.
smilesun Sep 19, 2023
b99ffa6
refine code, rename
smilesun Sep 19, 2023
e406cc3
.
smilesun Sep 19, 2023
8ec956a
.
smilesun Sep 19, 2023
38a8a22
.
smilesun Sep 19, 2023
db548ca
automatic initialization of epsilon as setpoint
smilesun Sep 19, 2023
abd88e2
.
smilesun Sep 19, 2023
6f0ac6e
rename
smilesun Sep 20, 2023
1a5c12b
Merge branch 'master' into fbopt
smilesun Sep 20, 2023
eb200c4
Merge branch 'fbopt' into fbopt_alternate
smilesun Sep 20, 2023
eba818f
Delete demo_draw_loss.py
smilesun Sep 20, 2023
e6c3b6f
unit test
smilesun Sep 20, 2023
dbcae6b
Merge branch 'xd_refactor_scheduler' into fbopt
smilesun Sep 20, 2023
fb663dd
Merge branch 'fbopt' into fbopt_alternate
smilesun Sep 20, 2023
17d3872
Merge pull request #385 from marrlab/fbopt_alternate
smilesun Sep 20, 2023
394549d
diva
smilesun Sep 20, 2023
043f75d
setpoint ratio hyper
smilesun Sep 20, 2023
b5117e9
.
smilesun Sep 20, 2023
55cd4e4
Merge branch 'master' into fbopt
smilesun Sep 20, 2023
d8301d6
.
smilesun Sep 20, 2023
9cb53b8
.
smilesun Sep 20, 2023
14b46f6
.
smilesun Sep 20, 2023
938f570
.
smilesun Sep 20, 2023
265b3ca
.
smilesun Sep 20, 2023
920c05d
.
smilesun Sep 20, 2023
dba0ac5
Merge branch 'master' into fbopt
smilesun Sep 20, 2023
5cbd1b5
Merge branch 'fbopt' into fbopt_vec
smilesun Sep 20, 2023
affa7a0
Merge branch 'fbopt' into fbopt_setpoint_ada
smilesun Sep 20, 2023
08f1ba0
moving average of setpoint switch
smilesun Sep 20, 2023
9096b39
Merge pull request #396 from marrlab/fbopt_setpoint_ada
smilesun Sep 20, 2023
9f7a126
Merge branch 'fbopt' into fbopt_vec
smilesun Sep 20, 2023
551e7d1
need gain to be a dictionary
smilesun Sep 20, 2023
98bc188
.
smilesun Sep 20, 2023
bb765b3
.
smilesun Sep 20, 2023
57bde44
.
smilesun Sep 21, 2023
1d1dbed
Merge remote-tracking branch 'marrlab/xd_fix_bug_diva_hduva_loss_zero…
smilesun Sep 21, 2023
f98b82f
ready
smilesun Sep 21, 2023
c1ca511
runs without error
smilesun Sep 21, 2023
cdf90c0
add taskloss to tensorboard plots
Sep 21, 2023
97c6e1c
Merge branch 'master' into cf_plot_l_for_bopt
Car-la-F Sep 21, 2023
9623a58
Merge pull request #408 from marrlab/cf_plot_l_for_bopt
smilesun Sep 21, 2023
e9fa887
Merge pull request #397 from marrlab/fbopt_vec
smilesun Sep 21, 2023
532568c
Merge branch 'fbopt' into fbopt_pareto_front_slide
smilesun Sep 21, 2023
f7f1ad1
Merge branch 'fbopt' into fbopt_tensorboard_twoline
smilesun Sep 21, 2023
6693bb1
Merge branch 'fbopt' of https://github.com/marrlab/DomainLab into fbopt
drEast Sep 21, 2023
4fdf025
Merge branch 'fbopt_tensorboard_twoline' of https://github.com/marrla…
drEast Sep 21, 2023
5a25336
Merge branch 'fbopt' into lb_benchm_fbopt
smilesun Sep 21, 2023
db08d24
Merge pull request #375 from marrlab/lb_benchm_fbopt
smilesun Sep 21, 2023
4e5baed
Merge branch 'fbopt' into fbopt_tensorboard_twoline
smilesun Sep 21, 2023
872a025
Merge branch 'master' into fbopt
smilesun Sep 21, 2023
00b813c
Merge branch 'fbopt' into fbopt_tensorboard_twoline
smilesun Sep 21, 2023
82cd554
Tensorboard plotting: reg/dyn_i vs reg/setpoint_i
drEast Sep 21, 2023
9c280c8
Tensorboard logging improved dual line design
drEast Sep 21, 2023
01ae0e5
Merge branch 'fbopt' into fbopt_pareto_front_slide
smilesun Sep 21, 2023
6175d9a
Tensorflow Plotting: reg/dyn vs task plot
drEast Sep 21, 2023
7659386
Merge branch 'fbopt_tensorboard_twoline' of https://github.com/marrla…
drEast Sep 21, 2023
9553f4f
Merge pull request #398 from marrlab/fbopt_tensorboard_twoline
drEast Sep 21, 2023
9de9c2d
Merge branch 'master' into fbopt
smilesun Sep 21, 2023
61dd6a3
Hotfix: But #411 - List multiplication with scalar for update setpoint
drEast Sep 21, 2023
7d71982
Merge pull request #413 from marrlab/fbopt_hotfixes
drEast Sep 21, 2023
d83501e
Added activation clip
drEast Sep 22, 2023
d292207
Adapt Readme installable on cluster
drEast Sep 22, 2023
67431d5
Merge branch 'fbopt' into fbopt_adaptReadme
smilesun Sep 22, 2023
d9f3248
Merge pull request #420 from marrlab/fbopt_adaptReadme
smilesun Sep 22, 2023
64f2ce3
Merge branch 'master' into fbopt
smilesun Sep 22, 2023
3b99621
Merge pull request #421 from marrlab/fbopt_clipActivation
smilesun Sep 22, 2023
75ebf70
Merge branch 'fbopt' into fbopt_pareto_front_slide
smilesun Sep 22, 2023
0f3c8b4
clean
smilesun Sep 22, 2023
342b343
Clearified x and y label in task vs reg plots
drEast Sep 22, 2023
fac9616
Plot all mmu instead of only primary
drEast Sep 22, 2023
ef4e821
hyper parameter
smilesun Sep 22, 2023
ae1c373
Merge pull request #399 from marrlab/fbopt_pareto_front_slide
smilesun Sep 22, 2023
d7dd05f
Added penalized loss
drEast Sep 22, 2023
03d8985
rename setpint
smilesun Sep 22, 2023
83aa0df
Merge branch 'master' into fbopt
smilesun Sep 22, 2023
981fe3e
Added penalized loss
drEast Sep 22, 2023
ae4fd7e
first eeidiotn
smilesun Sep 22, 2023
6e7daed
working yaml file
smilesun Sep 22, 2023
def1e67
working
smilesun Sep 22, 2023
16e810a
Merge pull request #428 from marrlab/fbopt_example_yaml
smilesun Sep 22, 2023
32b6a51
.
smilesun Sep 22, 2023
f5047a5
.
smilesun Sep 22, 2023
6282b98
pacs yaml works
smilesun Sep 22, 2023
b710191
mnist benchmark works for fbopt
smilesun Sep 22, 2023
c8bd353
.
smilesun Sep 22, 2023
ee4e356
fix bug in yaml file
smilesun Sep 22, 2023
9db9445
.
smilesun Sep 22, 2023
ee0e12b
argument to disable tensorboard
Sep 22, 2023
c253e4d
Merge branch 'master' into fbopt
smilesun Sep 22, 2023
ac9474c
disabled tensorboard in fbopt benchmark
e-dorigatti Sep 22, 2023
71fea39
Merge branch 'fbopt' into fbopt_disable_tensorboard
e-dorigatti Sep 22, 2023
0e9a368
Merge pull request #431 from marrlab/fbopt_disable_tensorboard
smilesun Sep 22, 2023
ac287ca
change model selection back to validation accuracy
smilesun Sep 22, 2023
8567d74
Update arg_parser.py to have msel include "last"
smilesun Sep 22, 2023
fcba664
using slurm job id to differentiate log folders
e-dorigatti Sep 22, 2023
14c5dfe
Merge pull request #434 from marrlab/fbopt_logging_on_slurm
smilesun Sep 22, 2023
bb5c51b
Merge branch 'fbopt' into xd_fbopt_msel_hyper
smilesun Sep 22, 2023
d3bfa02
Update c_msel_bang.py
smilesun Sep 22, 2023
73c707a
Update train_fbopt.py
smilesun Sep 22, 2023
984cd0e
Update arg_parser.py
smilesun Sep 22, 2023
55678ab
Update train_fbopt.py
smilesun Sep 22, 2023
0ed622c
indentation
e-dorigatti Sep 22, 2023
e8cb8c0
Update train_fbopt.py
smilesun Sep 22, 2023
fcd9add
Update train_fbopt.py
smilesun Sep 22, 2023
bd12c17
Update train_fbopt.py
smilesun Sep 22, 2023
b9ac0bf
Update train_fbopt.py
smilesun Sep 22, 2023
8af4ab4
Update train_fbopt.py
smilesun Sep 22, 2023
104268a
Update c_obvisitor_cleanup.py
smilesun Sep 22, 2023
a5c92ec
Update c_obvisitor_cleanup.py
smilesun Sep 22, 2023
aaeb13b
Merge pull request #433 from marrlab/xd_fbopt_msel_hyper
smilesun Sep 22, 2023
e1f8b81
Update benchmark_fbopt_pacs_full.yaml
smilesun Sep 22, 2023
1f8c51d
Update benchmark_fbopt_pacs_full.yaml
smilesun Sep 22, 2023
402291f
new hyperparamter ranges
e-dorigatti Sep 22, 2023
44fee66
Merge pull request #437 from marrlab/fbopt_new_hparams
smilesun Sep 22, 2023
442a768
Update model_hduva.py
smilesun Sep 22, 2023
512a913
Create run_fbopt_hduva
smilesun Sep 22, 2023
55c25c5
Update train_fbopt.py
smilesun Sep 22, 2023
593c344
Create run_fbopt_dann.sh
smilesun Sep 22, 2023
1f6d355
.
smilesun Sep 23, 2023
ae34747
specify values for gamma_y in fbopt benchmark
e-dorigatti Sep 23, 2023
584476a
added torch and tensorboard dependencies
e-dorigatti Sep 23, 2023
ea2a3d9
removed tensorboard from readme as its in the requirements already
e-dorigatti Sep 23, 2023
345b6c1
Merge pull request #447 from marrlab/fbopt_fix_ci
smilesun Sep 23, 2023
a28eeb3
Merge pull request #446 from marrlab/fbopt_benchmark_typeerror_440
smilesun Sep 23, 2023
6c922a7
Update README.md
smilesun Sep 23, 2023
cb3f877
saving gamma_d in diva
e-dorigatti Sep 23, 2023
5f73e82
Merge branch 'fbopt' into fbopt_fix_diva
e-dorigatti Sep 23, 2023
355409d
added gamma_d as hyperparameter
e-dorigatti Sep 23, 2023
567453b
dynamically adjusting gamma_d
e-dorigatti Sep 23, 2023
fed7b0e
using gamma_y in hduva
e-dorigatti Sep 23, 2023
0f71c10
Merge pull request #448 from marrlab/fbopt_fix_diva
smilesun Sep 23, 2023
4db485b
use multiplier before p_loss, which won't change task loss
smilesun Sep 23, 2023
105f7a4
update script
smilesun Sep 23, 2023
8ed78f8
update script
smilesun Sep 23, 2023
6999ab7
Merge branch 'master' into fbopt
smilesun Sep 23, 2023
845ce03
fixed wrong setting for msel_tr_loss
e-dorigatti Sep 23, 2023
8533b5b
printing actual values of all parameters before running experiment
e-dorigatti Sep 23, 2023
d0777e8
Merge pull request #453 from marrlab/fbopt_fix_benchmark
smilesun Sep 23, 2023
3031451
Merge pull request #455 from marrlab/fbopt_better_debug
smilesun Sep 23, 2023
fd12ef7
merge conflict
smilesun Sep 23, 2023
c42531e
Merge pull request #426 from marrlab/fbopt_plotPenalized
drEast Sep 23, 2023
9b47eef
.
smilesun Sep 23, 2023
2276ec9
Merge branch 'master' into fbopt_master
smilesun Sep 23, 2023
1f24982
update yaml
smilesun Sep 23, 2023
78d586c
.
smilesun Sep 23, 2023
9fe261e
Merge branch 'fbopt' into fbopt_recon_multiplier
smilesun Sep 23, 2023
7d45139
.
smilesun Sep 23, 2023
45512dd
add register
smilesun Sep 23, 2023
98ade9b
Merge branch 'fbopt_recon_multiplier' of github.com:marrlab/DomainLab…
smilesun Sep 23, 2023
11ab785
add register mu component
smilesun Sep 23, 2023
1921f51
hduva
smilesun Sep 23, 2023
0a274d9
cpu
smilesun Sep 23, 2023
dca3c68
Merge pull request #459 from marrlab/fbopt_recon_multiplier
smilesun Sep 23, 2023
029d481
Merge branch 'fbopt' into fbopt_master
smilesun Sep 23, 2023
95 changes: 95 additions & 0 deletions domainlab/algos/trainers/fbopt.py
@@ -0,0 +1,95 @@
"""
update hyper-parameters during training
"""
import copy
from domainlab.utils.logger import Logger


class HyperSchedulerFeedback():
"""
design the $\\mu$ sequence based on the state of the penalized loss
"""
def __init__(self, trainer, **kwargs):
"""
kwargs is a dictionary with key the hyper-parameter name and its value
"""
self.trainer = trainer
self.mmu = kwargs
self.mmu = {key: 0.0 for key, val in self.mmu.items()}
self.ploss_old_theta_old_mu = None
self.ploss_old_theta_new_mu = None
self.ploss_new_theta_old_mu = None
self.ploss_new_theta_new_mu = None
self.delta_mu = 0.01 # FIXME
self.dict_theta = None
self.budget_mu_per_step = 5 # FIXME
self.budget_theta_update_per_mu = 5 # np.infty

def search_mu(self, dict_theta):
"""
start from parameter dict_theta,
enlarge mmu to see if the criteria is met
"""
self.dict_theta = dict_theta
mmu = None
for miter in range(self.budget_mu_per_step):
# FIXME: the same mu is tried two times since miter=0
mmu = self.dict_addition(self.mmu, miter * self.delta_mu)
Review comment (Collaborator Author): miter always starts from 0, but miter=0 only makes sense for the first iteration.

print(f"trying mu={mmu} at mu iteration {miter}")
if self.search_theta(mmu):
print(f"!!!found reg-pareto operator with mu={mmu}")
self.mmu = mmu
return True
logger = Logger.get_logger(logger_name='main_out_logger', loglevel="INFO")
logger.warn(f"!!!!!!failed to find mu within budget, mu={mmu}")
return False

def dict_addition(self, dict_base, delta):
"""
increase the value of a dictionary by delta
"""
return {key: val + delta for key, val in dict_base.items()}

def search_theta(self, mmu_new):
"""
conditioned on a fixed $\\mu$, the operator searches $\\theta$ starting from
the current value of $\\theta$

executing this method also sets the values of mu and theta
"""
flag_success = False
self.ploss_old_theta_new_mu = self.trainer.eval_loss(mmu_new, self.dict_theta)
self.ploss_old_theta_old_mu = self.trainer.eval_loss(self.mmu, self.dict_theta)
theta4mu_new = copy.deepcopy(self.dict_theta)
for i in range(self.budget_theta_update_per_mu):
print(f"update theta at iteration {i} with mu={mmu_new}")
theta4mu_new = self.trainer.opt_theta(mmu_new, theta4mu_new)
self.ploss_new_theta_new_mu = self.trainer.eval_loss(mmu_new, theta4mu_new)
self.ploss_new_theta_old_mu = self.trainer.eval_loss(self.mmu, theta4mu_new)
if self.is_criteria_met():
self.mmu = mmu_new
flag_success = True
# FIXME: update theta only if current mu is good enough?
self.dict_theta = theta4mu_new
return flag_success
return flag_success

def inner_product(self, mmu, v_reg_loss):
"""
- the first dimension of the tensor v_reg_loss is the mini-batch,
the second dimension is the number of regularizers
- the vector mmu has one entry per regularizer
"""
return mmu * v_reg_loss

def is_criteria_met(self):
"""
if the reg-descent criteria is met
"""
flag_improve = self.ploss_new_theta_new_mu < self.ploss_old_theta_new_mu
flag_deteriorate = self.ploss_new_theta_old_mu > self.ploss_old_theta_old_mu
return flag_improve & flag_deteriorate

def __call__(self, epoch):
"""
"""
99 changes: 99 additions & 0 deletions domainlab/algos/trainers/train_fbopt.py
@@ -0,0 +1,99 @@
"""
feedback optimization
"""
import copy

from domainlab.algos.trainers.a_trainer import AbstractTrainer
from domainlab.algos.trainers.train_basic import TrainerBasic
from domainlab.algos.trainers.fbopt import HyperSchedulerFeedback


class HyperSetter():
"""
mock object to force hyper-parameter in the model
"""
def __init__(self, dict_hyper):
self.dict_hyper = dict_hyper

def __call__(self, epoch=None):
return self.dict_hyper


class TrainerFbOpt(AbstractTrainer):
"""
feedback optimization
"""
def set_scheduler(self, scheduler=HyperSchedulerFeedback):
"""
Args:
scheduler: The class name of the scheduler, the object corresponding to
this class name will be created inside model
"""
# model.hyper_init will register the hyper-parameters of the model to scheduler
self.hyper_scheduler = self.model.hyper_init(scheduler, trainer=self)

def before_tr(self):
"""
before training begins, construct helper objects
"""
self.set_scheduler(scheduler=HyperSchedulerFeedback)
self.model.evaluate(self.loader_te, self.device)
self.inner_trainer = TrainerBasic() # look ahead
# here we need a mechanism to generate deep copy of the model
self.inner_trainer.init_business(
copy.deepcopy(self.model), self.task, self.observer, self.device, self.aconf,
flag_accept=False)

def opt_theta(self, dict4mu, dict_theta0):
"""
operator for theta, move gradient for one epoch, then check if criteria is met
this method will be invoked by the hyper-parameter scheduling object
"""
self.inner_trainer.model.set_params(dict_theta0)
Review comment (Collaborator Author): Model.train

# mock the model hyper-parameter to be from dict4mu
self.inner_trainer.model.hyper_update(epoch=None, fun_scheduler=HyperSetter(dict4mu))
# hide implementation details of inner_trainer
for _, (tensor_x, vec_y, vec_d, *others) in enumerate(self.inner_trainer.loader_tr):
self.inner_trainer.train_batch(tensor_x, vec_y, vec_d, others) # update inner_net
dict_par = dict(self.inner_trainer.model.named_parameters())
return dict_par

def eval_loss(self, dict4mu, dict_theta):
"""
evaluate the penalty function value
"""
temp_model = copy.deepcopy(self.model)
# mock the model hyper-parameter to be from dict4mu
temp_model.hyper_update(epoch=None, fun_scheduler=HyperSetter(dict4mu))
temp_model.set_params(dict_theta)
epo_reg_loss = 0
# FIXME: check if reg is decreasing
epo_task_loss = 0
epo_p_loss = 0 # penalized loss
# FIXME: will the loader be corrupted if called from different places and we do not make a deep copy?
for _, (tensor_x, vec_y, vec_d, *_) in enumerate(self.loader_tr):
tensor_x, vec_y, vec_d = \
tensor_x.to(self.device), vec_y.to(self.device), vec_d.to(self.device)
b_reg_loss = temp_model.cal_reg_loss(tensor_x, vec_y, vec_d).sum()
b_task_loss = temp_model.cal_task_loss(tensor_x, vec_y).sum()
# sum will kill the dimension of the mini batch
b_p_loss = temp_model.cal_loss(tensor_x, vec_y, vec_d).sum()
epo_reg_loss += b_reg_loss
epo_task_loss += b_task_loss
epo_p_loss += b_p_loss
return epo_p_loss

def tr_epoch(self, epoch):
self.model.train()
flag_success = self.hyper_scheduler.search_mu(
dict(self.model.named_parameters())) # if mu not found, will terminate
if flag_success:
# only in success case, mu will be updated
self.model.set_params(self.hyper_scheduler.dict_theta)
else:
# if failed to find reg-pareto descent operator, continue training
theta = dict(self.model.named_parameters())
dict_par = self.opt_theta(self.hyper_scheduler.mmu, copy.deepcopy(theta))
self.model.set_params(dict_par)
flag_stop = self.observer.update(epoch) # FIXME: should count how many epochs were used
return flag_stop
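
To see the whole feedback loop of search_mu, search_theta, is_criteria_met and tr_epoch in one place, here is a toy, self-contained one-dimensional simulation; the quadratic stand-in losses, the learning rate and the budget are invented for illustration and are not taken from the PR:

def toy_losses(theta):
    task = (theta - 1.0) ** 2   # hypothetical task loss
    reg = theta ** 2            # hypothetical regularization loss
    return task, reg

def ploss(theta, mu):
    task, reg = toy_losses(theta)
    return task + mu * reg

theta, mu, delta_mu, lr = 1.0, 0.0, 0.5, 0.1
for miter in range(5):  # plays the role of budget_mu_per_step
    mu_new = mu + miter * delta_mu
    # one look-ahead gradient step on the penalized loss under mu_new
    grad = 2 * (theta - 1.0) + mu_new * 2 * theta
    theta_new = theta - lr * grad
    flag_improve = ploss(theta_new, mu_new) < ploss(theta, mu_new)
    flag_deteriorate = ploss(theta_new, mu) > ploss(theta, mu)
    if flag_improve and flag_deteriorate:
        theta, mu = theta_new, mu_new   # reg-descent operator found, adopt both
        break
print(theta, mu)  # 0.9, 0.5 with the numbers above

In the PR the same pattern runs over the real model: opt_theta performs the look-ahead epoch on a deep copy, eval_loss supplies the penalized-loss values, and tr_epoch adopts the look-ahead weights only on success.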
1 change: 1 addition & 0 deletions domainlab/algos/trainers/train_hyper_scheduler.py
@@ -26,6 +26,7 @@ def set_scheduler(self, scheduler, total_steps,
flag_update_batch: if hyper-parameters should be changed per batch
"""
self.hyper_scheduler = self.model.hyper_init(scheduler)
# let model register its hyper-parameters to the scheduler
self.flag_update_hyper_per_epoch = flag_update_epoch
self.flag_update_hyper_per_batch = flag_update_batch
self.hyper_scheduler.set_steps(total_steps=total_steps)
4 changes: 3 additions & 1 deletion domainlab/algos/trainers/zoo_trainer.py
@@ -6,6 +6,7 @@
from domainlab.algos.trainers.train_matchdg import TrainerMatchDG
from domainlab.algos.trainers.train_mldg import TrainerMLDG
from domainlab.algos.trainers.train_hyper_scheduler import TrainerHyperScheduler
from domainlab.algos.trainers.train_fbopt import TrainerFbOpt


class TrainerChainNodeGetter(object):
@@ -38,6 +39,7 @@ def __call__(self, lst_candidates=None, default=None, lst_excludes=None):
chain = TrainerDIAL(chain)
chain = TrainerMatchDG(chain)
chain = TrainerMLDG(chain)
chain = TrainerHyperScheduler(chain) # FIXME: change to warmup
chain = TrainerHyperScheduler(chain)
chain = TrainerFbOpt(chain)
node = chain.handle(self.request)
return node
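
The getter above is a chain of responsibility: each trainer node either handles the requested trainer name or defers to the node it wraps, and TrainerFbOpt is simply appended as the outermost node. A standalone sketch of that dispatch idea, with invented class and trainer names:

class NodeSketch:
    """hypothetical stand-in for a trainer chain node"""
    def __init__(self, name, successor=None):
        self.name, self.successor = name, successor

    def handle(self, request):
        if request == self.name:
            return self
        if self.successor is None:
            raise NotImplementedError(f"no trainer registered for {request!r}")
        return self.successor.handle(request)

chain = NodeSketch("basic")
chain = NodeSketch("hyperscheduler", chain)
chain = NodeSketch("fbopt", chain)     # mirrors chain = TrainerFbOpt(chain)
print(chain.handle("fbopt").name)      # 'fbopt'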
7 changes: 7 additions & 0 deletions domainlab/models/a_model.py
@@ -11,6 +11,13 @@ class AModel(nn.Module, metaclass=abc.ABCMeta):
"""
operations that all models (classification, segmentation, seq2seq)
"""
def set_params(self, dict_params):
"""
set the module parameters from a dictionary of named parameters
"""
# FIXME: net1.load_state_dict(net2.state_dict()) contains more information than model.named_parameters() like optimizer status
self.load_state_dict(dict_params, strict=False)

def cal_loss(self, tensor_x, tensor_y, tensor_d=None, others=None):
"""
calculate the loss
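
A small aside on the set_params/FIXME above: dict(model.named_parameters()) contains only learnable parameters, while a full state_dict() additionally carries buffers such as BatchNorm running statistics, which is why strict=False is needed when loading. A minimal sketch, assuming a standard PyTorch installation:

import torch.nn as nn

net = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4))
params = dict(net.named_parameters())   # weights and biases only
full = net.state_dict()                 # also running_mean, running_var, num_batches_tracked
print(sorted(set(full) - set(params)))
# ['1.num_batches_tracked', '1.running_mean', '1.running_var']
missing, unexpected = net.load_state_dict(params, strict=False)
print(missing)   # the buffer keys above are reported as missing and left untouched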
6 changes: 3 additions & 3 deletions domainlab/models/model_dann.py
@@ -16,7 +16,7 @@ def mk_dann(parent_class=AModelClassif):
The model is trained to solve two tasks:
1. Standard image classification.
2. Domain classification.
Here for, a feature extractor is adversarially trained to minimize the loss of the image
Here for, a feature extractor is adversarially trained to minimize the loss of the image
classifier and maximize the loss of the domain classifier.
For more details, see:
Ganin, Yaroslav, et al. "Domain-adversarial training of neural networks."
@@ -66,11 +66,11 @@ def hyper_update(self, epoch, fun_scheduler):
dict_rst = fun_scheduler(epoch) # the __call__ method of hyperparameter scheduler
self.alpha = dict_rst["alpha"]

def hyper_init(self, functor_scheduler):
def hyper_init(self, functor_scheduler, trainer=None):
"""hyper_init.
:param functor_scheduler:
"""
return functor_scheduler(alpha=self.alpha)
return functor_scheduler(trainer=trainer, alpha=self.alpha)

def cal_logit_y(self, tensor_x): # FIXME: this is only for classification
"""
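
The signature change above (hyper_init gaining a trainer argument) is the hook that lets HyperSchedulerFeedback call back into the trainer. A stripped-down sketch of the registration contract; the classes and the toy schedule below are invented stand-ins, not the real mk_dann model or scheduler:

class DannSketch:
    """hypothetical model exposing a single hyper-parameter alpha"""
    def __init__(self, alpha):
        self.alpha = alpha

    def hyper_init(self, functor_scheduler, trainer=None):
        # register the current hyper-parameter value (and the trainer) with the scheduler
        return functor_scheduler(trainer=trainer, alpha=self.alpha)

    def hyper_update(self, epoch, fun_scheduler):
        self.alpha = fun_scheduler(epoch)["alpha"]

class SchedulerSketch:
    """hypothetical scheduler taking the trainer plus keyword hyper-parameters"""
    def __init__(self, trainer=None, **kwargs):
        self.trainer, self.mmu = trainer, kwargs

    def __call__(self, epoch):
        return {key: 0.1 * epoch for key in self.mmu}   # toy schedule

model = DannSketch(alpha=1.0)
scheduler = model.hyper_init(SchedulerSketch, trainer=None)   # as TrainerFbOpt.set_scheduler does with trainer=self
model.hyper_update(epoch=3, fun_scheduler=scheduler)
print(model.alpha)   # approximately 0.3 under the toy schedule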
6 changes: 6 additions & 0 deletions run_fbopt.sh
@@ -0,0 +1,6 @@
#!/bin/bash
# export CUDA_VISIBLE_DEVICES=""
# although the garbage collector has been called explicitly, a CUDA out-of-memory error can still occur,
# so it is better not to use the GPU for pytest, to ensure no CUDA out-of-memory error occurs
# pytest -s tests/test_fbopt.py
python main_out.py --te_d=caltech --task=mini_vlcs --bs=16 --aname=dann --trainer=fbopt --nname=alexnet --epos=20
6 changes: 6 additions & 0 deletions run_fbopt_pacs.sh
@@ -0,0 +1,6 @@
#!/bin/bash
# export CUDA_VISIBLE_DEVICES=""
# although garbage collector has been explicitly called, sometimes there is still CUDA out of memory error
# so it is better not to use GPU to do the pytest to ensure every time there is no CUDA out of memory error occuring
# pytest -s tests/test_fbopt.py
python main_out.py --te_d=sketch --tpath=examples/tasks/task_pacs_path_list.py --bs=4 --aname=dann --trainer=fbopt --nname=alexnet --epos=20
12 changes: 12 additions & 0 deletions tests/test_fbopt.py
@@ -0,0 +1,12 @@
"""
unit and end-to-end test for the fbopt trainer
"""
from tests.utils_test import utils_test_algo


def test_deepall_fbopt():
"""
train DANN with the fbopt trainer
"""
args = "--te_d=caltech --task=mini_vlcs --debug --bs=2 --aname=dann --trainer=fbopt --nname=alexnet --epos=3"
utils_test_algo(args)