Skip to content

Commit

Permalink
Merge branch 'fbopt_output_ma' into fbopt_setpoint_ma_output_ma
Browse files Browse the repository at this point in the history
  • Loading branch information
smilesun committed Oct 10, 2023
2 parents b4f364c + 5cdc3c5 commit e748524
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 5 deletions.
17 changes: 17 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,23 @@ pip install -r requirements_notorch.txt
conda install tensorboard
```

#### Download PACS

step 1:

Use the following script to download PACS to your local machine, then upload it to your cluster:

https://github.com/marrlab/DomainLab/blob/fbopt/data/script/download_pacs.py

step 2:
Make a symbolic link, following the example script in https://github.com/marrlab/DomainLab/blob/master/sh_pacs.sh

First run `mkdir -p data/pacs` in the repository directory. Then

`ln -s /dir/to/yourdata/pacs/raw ./data/pacs/PACS`
creates a symbolic link inside the repository directory that points to your PACS data.


#### Windows installation details

To install DomainLab on Windows, please remove the `snakemake` dependency from the `requirements.txt` file.
Expand Down
4 changes: 4 additions & 0 deletions domainlab/algos/trainers/args_fbopt.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ def add_args2parser_fbopt(parser):

parser.add_argument('--coeff_ma', type=float, default=0.5,
help='exponential moving average')

parser.add_argument('--coeff_ma_output_state', type=float, default=0.5,
help='setpoint output as state exponential moving average')


parser.add_argument('--exp_shoulder_clip', type=float, default=10,
help='clip before exponential operation')
Expand Down
16 changes: 12 additions & 4 deletions domainlab/algos/trainers/fbopt_setpoint_ada.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@
import torch
from domainlab.utils.logger import Logger

def list_add(list1, list2):
    """
    Add two lists element-wise.

    Elements are paired with ``zip``, so the result is only as long as the
    shorter input; surplus elements of the longer list are ignored.

    :param list1: first list of addable elements
    :param list2: second list of addable elements
    :return: new list whose i-th entry is ``list1[i] + list2[i]``
    """
    return [a + b for a, b in zip(list1, list2)]

def list_multiply(list1, coeff):
    """
    Multiply every element of a list by a common coefficient.

    :param list1: list of elements supporting ``element * coeff``
    :param coeff: factor applied to each element
    :return: new list whose i-th entry is ``list1[i] * coeff``; the input
        list is not modified
    """
    return [ele * coeff for ele in list1]

def is_less_list_any(list1, list2):
"""
Expand Down Expand Up @@ -36,9 +41,10 @@ def __init__(self, state=None, args=None):
state = DominateAnyComponent()
self.transition_to(state)
self.ma_epo_reg_loss = None
self.state_epo_reg_loss = None
self.coeff_ma = 0.5 # FIXME
self.state_task_loss = None
self.state_task_loss = 0.0
self.state_epo_reg_loss = [0.0 for _ in range(10)] # FIXME
self.coeff_ma_output = args.coeff_ma_output_state
# initial value will be set via trainer
self.setpoint4R = None
self.setpoint4ell = None
Expand All @@ -63,8 +69,10 @@ def observe(self, epo_reg_loss, epo_task_loss):
read current epo_reg_loss continuously
FIXME: setpoint should also be able to be eliviated
"""
self.state_epo_reg_loss = epo_reg_loss
self.state_task_loss = epo_task_loss
self.state_epo_reg_loss = [self.coeff_ma_output*a + ( 1-self.coeff_ma_output )*b if a != 0.0 else b for a, b in zip(self.state_epo_reg_loss, epo_reg_loss)]
if self.state_task_loss == 0.0:
self.state_task_loss = epo_task_loss
self.state_task_loss = self.coeff_ma_output * self.state_task_loss + (1-self.coeff_ma_output) * epo_task_loss
if self.state_updater.update_setpoint():
logger = Logger.get_logger(logger_name='main_out_logger', loglevel="INFO")
self.update_setpoint_ma(self.state_epo_reg_loss)
Expand Down
3 changes: 2 additions & 1 deletion run_fbopt_mnist.sh
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,5 @@
# although the garbage collector has been called explicitly, a CUDA out-of-memory error sometimes still occurs,
# so it is better not to run pytest on the GPU, to ensure that no CUDA out-of-memory error ever occurs
# pytest -s tests/test_fbopt.py
python main_out.py --te_d=1 --tr_d 0 3 --task=mnistcolor10 --bs=16 --aname=jigen --trainer=fbopt --nname=conv_bn_pool_2 --epos=2000 --es=50 --mu_init=0.00001

python main_out.py --te_d=1 --tr_d 0 3 --task=mnistcolor10 --bs=16 --aname=jigen --trainer=fbopt --nname=conv_bn_pool_2 --epos=2000 --es=50 --mu_init=0.00001 --coeff_ma_output_state=0.5

0 comments on commit e748524

Please sign in to comment.