Merge branch 'fbopt' into fbopt_grad_clip
smilesun authored Oct 10, 2023
2 parents 435acf6 + 977a86b commit 0d425a4
Showing 23 changed files with 373 additions and 68 deletions.
20 changes: 20 additions & 0 deletions README.md
@@ -37,12 +37,32 @@ then

#### Guide for Helmholtz GPU cluster
```
conda create --name domainlab_py39 python=3.9
conda activate domainlab_py39
conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.6 -c pytorch -c conda-forge
conda install torchmetrics==0.10.3
git checkout fbopt
pip install -r requirements_notorch.txt
conda install tensorboard
```

#### Download PACS

Step 1:

Use the following script to download PACS to your local machine, then upload it to the cluster:

https://github.com/marrlab/DomainLab/blob/fbopt/data/script/download_pacs.py

Step 2:

Make a symbolic link, following the example script in https://github.com/marrlab/DomainLab/blob/master/sh_pacs.sh:

`mkdir -p data/pacs` is executed under the repository directory, and

`ln -s /dir/to/yourdata/pacs/raw ./data/pacs/PACS`
then creates a symbolic link under the repository directory pointing to the downloaded data (a Python equivalent is sketched below).
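
For reference, a small Python sketch of the two commands above; the source path is a placeholder, exactly as in the shell example:

```python
# Sketch: replicate `mkdir -p data/pacs` and `ln -s ... ./data/pacs/PACS`
# from the repository root; "/dir/to/yourdata/pacs/raw" is a placeholder.
from pathlib import Path

Path("data/pacs").mkdir(parents=True, exist_ok=True)
link = Path("data/pacs/PACS")
if not link.is_symlink():
    link.symlink_to("/dir/to/yourdata/pacs/raw")
```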


#### Windows installation details

To install DomainLab on Windows, please remove the `snakemake` dependency from the `requirements.txt` file.
2 changes: 2 additions & 0 deletions ci_run_examples.sh
@@ -8,6 +8,8 @@ sed -n '/```shell/,/```/ p' docs/doc_examples.md | sed '/^```/ d' >> ./sh_temp_example.sh
bash -x -v -e sh_temp_example.sh
echo "general examples done"

rm -r zoutput

echo "#!/bin/bash -x -v" > sh_temp_mnist.sh
sed -n '/```shell/,/```/ p' docs/doc_MNIST_classification.md | sed '/^```/ d' >> ./sh_temp_mnist.sh
bash -x -v -e sh_temp_mnist.sh
2 changes: 1 addition & 1 deletion domainlab/algos/builder_diva.py
@@ -35,7 +35,7 @@ def init_business(self, exp):
request = RequestVAEBuilderCHW(
task.isize.c, task.isize.h, task.isize.w, args)
node = VAEChainNodeGetter(request)()
model = mk_diva()(node,
model = mk_diva(str_mu=args.str_mu)(node,
zd_dim=args.zd_dim,
zy_dim=args.zy_dim,
zx_dim=args.zx_dim,
2 changes: 1 addition & 1 deletion domainlab/algos/msels/a_model_sel.py
@@ -25,7 +25,7 @@ def accept(self, trainer, tr_obs):
self.tr_obs = tr_obs

@abc.abstractmethod
def update(self):
def update(self, clear_counter=False):
"""
observer + visitor pattern to trainer
if the best model should be updated
4 changes: 2 additions & 2 deletions domainlab/algos/msels/c_msel_oracle.py
@@ -18,7 +18,7 @@ def __init__(self, msel=None):
self.best_oracle_acc = 0
self.msel = msel

def update(self):
def update(self, clear_counter=False):
"""
if the best model should be updated
"""
@@ -35,7 +35,7 @@ def update(self):
logger.info("new oracle model saved")
flag = True
if self.msel is not None:
return self.msel.update()
return self.msel.update(clear_counter)
return flag

def if_stop(self):
5 changes: 4 additions & 1 deletion domainlab/algos/msels/c_msel_tr_loss.py
@@ -17,7 +17,7 @@ def __init__(self, max_es):
self.max_es = max_es
super().__init__()

def update(self):
def update(self, clear_counter=False):
"""
if the best model should be updated
"""
@@ -34,6 +34,9 @@ def update(self):
logger.info(f"early stop counter: {self.es_c}")
logger.info(f"loss:{loss}, best loss: {self.best_loss}")
flag = False # do not update best model
if clear_counter:
logger.info("clearing counter")
self.es_c = 0
return flag

def if_stop(self):
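
The `clear_counter` flag threads through all three selectors (`MSelOracleVisitor` above, `MSelTrLoss` here, `MSelValPerf` below), letting the caller reset early-stopping patience, e.g. when the feedback scheduler moves to a new setpoint. A minimal toy sketch of the semantics, with the epoch loss passed in directly rather than read from the observer as in the real classes:

```python
class ToyTrLossSel:
    """toy train-loss selector mirroring the update(clear_counter) semantics"""
    def __init__(self, max_es):
        self.max_es = max_es            # early-stop patience
        self.es_c = 0                   # early-stop counter
        self.best_loss = float("inf")

    def update(self, loss, clear_counter=False):
        flag = True                     # True: current model becomes the best one
        if loss < self.best_loss:
            self.best_loss = loss
            self.es_c = 0               # improvement restores the counter
        else:
            self.es_c += 1
            flag = False                # do not update best model
        if clear_counter:               # e.g. the setpoint has changed
            self.es_c = 0
        return flag

    def if_stop(self):
        return self.es_c > self.max_es


sel = ToyTrLossSel(max_es=2)
for loss in [1.0, 0.9, 0.95, 0.97]:
    sel.update(loss)
sel.update(0.99, clear_counter=True)    # patience is reset, training continues
assert not sel.if_stop()
```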
10 changes: 6 additions & 4 deletions domainlab/algos/msels/c_msel_val.py
@@ -16,19 +16,19 @@ def __init__(self, max_es):
self.best_te_metric = 0.0
super().__init__(max_es) # construct self.tr_obs (observer)

def update(self):
def update(self, clear_counter=False):
"""
if the best model should be updated
"""
flag = True
if self.tr_obs.metric_val is None or self.tr_obs.str_msel == "loss_tr":
return super().update()
return super().update(clear_counter)
metric = self.tr_obs.metric_val[self.tr_obs.str_metric4msel]
if self.tr_obs.metric_te is not None:
metric_te_current = self.tr_obs.metric_te[self.tr_obs.str_metric4msel]
self.best_te_metric = max(self.best_te_metric, metric_te_current)

if metric > self.best_val_acc: # observer
if metric > self.best_val_acc: # update hat{model}
# different from loss, accuracy should be improved: the bigger the better
self.best_val_acc = metric
self.es_c = 0 # restore counter
@@ -45,5 +45,7 @@ def update(self):
f"corresponding to test acc: \
{self.sel_model_te_acc} / {self.best_te_metric}")
flag = False # do not update best model

if clear_counter:
logger.info("clearing counter")
self.es_c = 0
return flag
3 changes: 3 additions & 0 deletions domainlab/algos/trainers/args_fbopt.py
@@ -36,6 +36,9 @@ def add_args2parser_fbopt(parser):
parser.add_argument('--no_setpoint_update', action='store_true', default=False,
help='disable setpoint update')

parser.add_argument('--str_mu', type=str, default="default", help='which penalty to tune')


# the following hyperparameters do not need to be tuned
parser.add_argument('--beta_mu', type=float, default=1.1,
help='how much to multiply mu each time')
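
A hedged sketch of how the new flag travels from the command line to the model factory (cf. the `builder_diva.py` change above; parsing is reduced to this one argument):

```python
import argparse

from domainlab.models.model_diva import mk_diva

parser = argparse.ArgumentParser()
parser.add_argument('--str_mu', type=str, default="default",
                    help='which penalty to tune')
args = parser.parse_args(["--str_mu", "gammad"])
model_class = mk_diva(str_mu=args.str_mu)   # selects ModelDIVAGammad, see model_diva.py
```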
5 changes: 1 addition & 4 deletions domainlab/algos/trainers/fbopt_alternate.py
@@ -98,7 +98,7 @@ def cal_delta4control(self, list1, list_setpoint):
def cal_delta_integration(self, list_old, list_new, coeff):
return [(1-coeff)*a + coeff*b for a, b in zip(list_old, list_new)]

def search_mu(self, epo_reg_loss, epo_task_loss, dict_theta=None, miter=None):
def search_mu(self, epo_reg_loss, epo_task_loss, epo_loss_tr, dict_theta=None, miter=None):
"""
start from parameter dictionary dict_theta: {"layer":tensor},
enlarge mu w.r.t. its current value
@@ -137,9 +137,6 @@ def search_mu(self, epo_reg_loss, epo_task_loss, dict_theta=None, miter=None):
f'reg/setpoint{i}': reg_set,
}, miter)
self.writer.add_scalar(f'x-axis=task vs y-axis=reg/dyn{i}', reg_dyn, epo_task_loss)

epo_loss_tr = epo_task_loss + torch.inner(
torch.Tensor(list(self.mmu.values())), torch.Tensor(epo_reg_loss))
self.writer.add_scalar('loss_penalized', epo_loss_tr, miter)
self.writer.add_scalar('task', epo_task_loss, miter)
acc_te = 0
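
The deleted lines show where the penalized loss used to be assembled in place; `search_mu` now receives `epo_loss_tr` precomputed by the trainer (see `train_mu_controller.py` below). For reference, the removed computation amounts to the following, with toy numbers standing in for the real values (`mmu` maps each regularizer to its multiplier):

```python
import torch

mmu = {"recon": 1.0, "kl_d": 0.5}   # toy multiplier dictionary
epo_reg_loss = [2.0, 4.0]           # toy per-regularizer epoch losses
epo_task_loss = 1.5

# penalized loss = task loss + mu-weighted sum of regularizer losses
epo_loss_tr = epo_task_loss + torch.inner(
    torch.Tensor(list(mmu.values())), torch.Tensor(epo_reg_loss))
assert epo_loss_tr.item() == 1.5 + 1.0 * 2.0 + 0.5 * 4.0
```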
12 changes: 8 additions & 4 deletions domainlab/algos/trainers/train_mu_controller.py
@@ -45,24 +45,27 @@ def eval_r_loss(self):
# mock the model hyper-parameter to be from dict4mu
epo_reg_loss = []
epo_task_loss = 0
epo_p_loss = 0
counter = 0.0
with torch.no_grad():
for _, (tensor_x, vec_y, vec_d, *_) in enumerate(self.loader_tr_no_drop):
tensor_x, vec_y, vec_d = \
tensor_x.to(self.device), vec_y.to(self.device), vec_d.to(self.device)
tuple_reg_loss = self.model.cal_reg_loss(tensor_x, vec_y, vec_d)
p_loss, *_ = self.model.cal_loss(tensor_x, vec_y, vec_d)
# NOTE: first [0] extract the loss, second [0] get the list
list_b_reg_loss = tuple_reg_loss[0]
list_b_reg_loss_sumed = [ele.sum().item() for ele in list_b_reg_loss]
list_b_reg_loss_sumed = [ele.sum().detach().item() for ele in list_b_reg_loss]
if len(epo_reg_loss) == 0:
epo_reg_loss = list_b_reg_loss_sumed
else:
epo_reg_loss = list(map(add, epo_reg_loss, list_b_reg_loss_sumed))
b_task_loss = self.model.cal_task_loss(tensor_x, vec_y).sum()
b_task_loss = self.model.cal_task_loss(tensor_x, vec_y).sum().detach().item()
# sum removes the mini-batch dimension
epo_task_loss += b_task_loss
epo_p_loss += p_loss.sum().detach().item()
counter += 1.0
return list_divide(epo_reg_loss, counter), epo_task_loss/counter
return list_divide(epo_reg_loss, counter), epo_task_loss/counter, epo_p_loss / counter

def before_batch(self, epoch, ind_batch):
"""
@@ -77,7 +80,7 @@ def before_batch(self, epoch, ind_batch):
def before_tr(self):
self.set_scheduler(scheduler=HyperSchedulerFeedbackAlternave)
self.model.hyper_update(epoch=None, fun_scheduler=HyperSetter(self.hyper_scheduler.mmu))
self.epo_reg_loss_tr, self.epo_task_loss_tr = self.eval_r_loss()
self.epo_reg_loss_tr, self.epo_task_loss_tr, self.epo_loss_tr = self.eval_r_loss()
self.hyper_scheduler.set_setpoint(
[ele * self.aconf.ini_setpoint_ratio for ele in self.epo_reg_loss_tr],
self.epo_task_loss_tr)
@@ -90,6 +93,7 @@ def tr_epoch(self, epoch):
self.hyper_scheduler.search_mu(
self.epo_reg_loss_tr,
self.epo_task_loss_tr,
self.epo_loss_tr,
dict(self.model.named_parameters()),
miter=epoch)
self.hyper_scheduler.update_setpoint(self.epo_reg_loss_tr, self.epo_task_loss_tr)
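
The accumulation in `eval_r_loss` now converts every batch loss to a plain Python float with `.detach().item()`, so no autograd graph is kept alive across the epoch. A hedged standalone sketch of the pattern:

```python
import torch

epo_p_loss, counter = 0.0, 0.0
for _ in range(3):                                       # stand-in for loader_tr_no_drop
    p_loss = torch.randn(8, requires_grad=True) ** 2     # toy per-sample penalized loss
    epo_p_loss += p_loss.sum().detach().item()           # drop the graph, keep the number
    counter += 1.0
mean_p_loss = epo_p_loss / counter                       # epoch average, as returned
```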
4 changes: 3 additions & 1 deletion domainlab/models/a_model_classif.py
@@ -197,4 +197,6 @@ def cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others=None):
"""
for ERM to adapt to the interface of other regularized learners
"""
return [torch.Tensor([0])], [0.0]
device = tensor_x.device
bsize = tensor_x.shape[0]
return [torch.zeros(bsize, 1).to(device)], [0.0]
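
This change gives ERM's dummy regularizer the same shape contract as a real one: one row per sample, on the input's device. Downstream code such as `eval_r_loss` can then call `ele.sum()` uniformly; a quick hedged check:

```python
import torch

tensor_x = torch.randn(16, 3, 32, 32)        # toy input batch
reg = torch.zeros(tensor_x.shape[0], 1).to(tensor_x.device)
assert reg.shape[0] == tensor_x.shape[0]     # per-sample, like real regularizers
assert reg.sum().item() == 0.0               # contributes nothing to the penalty
```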
75 changes: 69 additions & 6 deletions domainlab/models/model_diva.py
@@ -9,7 +9,7 @@
from domainlab.utils.utils_class import store_args


def mk_diva(parent_class=VAEXYDClassif):
def mk_diva(parent_class=VAEXYDClassif, str_mu="default"):
"""
Instantiate a domain invariant variational autoencoder (DIVA) with arbitrary task loss.
@@ -89,8 +89,6 @@ def hyper_update(self, epoch, fun_scheduler):
self.beta_d = dict_rst["beta_d"]
self.beta_y = dict_rst["beta_y"]
self.beta_x = dict_rst["beta_x"]
self.gamma_d = dict_rst["gamma_d"]
self.mu_recon = dict_rst["mu_recon"]

def hyper_init(self, functor_scheduler, trainer=None):
"""
@@ -100,11 +98,9 @@ def hyper_init(self, functor_scheduler, trainer=None):
"""
return functor_scheduler(
trainer=trainer,
mu_recon=self.mu_recon,
beta_d=self.beta_d,
beta_y=self.beta_y,
beta_x=self.beta_x,
gamma_d=self.gamma_d,
)

def get_list_str_y(self):
@@ -142,4 +138,71 @@ def cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others=None):
lc_d = F.cross_entropy(logit_d, d_target, reduction="none")
return [loss_recon_x, zd_p_minus_zd_q, zx_p_minus_zx_q, zy_p_minus_zy_q, lc_d], \
[self.mu_recon, -self.beta_d, -self.beta_x, -self.beta_y, -self.gamma_d]
return ModelDIVA

class ModelDIVAGammadRecon(ModelDIVA):
def hyper_update(self, epoch, fun_scheduler):
"""hyper_update.
:param epoch:
:param fun_scheduler:
"""
dict_rst = fun_scheduler(epoch)
self.beta_d = dict_rst["beta_d"]
self.beta_y = dict_rst["beta_y"]
self.beta_x = dict_rst["beta_x"]
self.gamma_d = dict_rst["gamma_d"]
self.mu_recon = dict_rst["mu_recon"]

def hyper_init(self, functor_scheduler, trainer=None):
"""
initiate a scheduler object via class name and things inside this model
:param functor_scheduler: the class name of the scheduler
"""
return functor_scheduler(
trainer=trainer,
mu_recon=self.mu_recon,
beta_d=self.beta_d,
beta_y=self.beta_y,
beta_x=self.beta_x,
gamma_d=self.gamma_d,
)


class ModelDIVAGammad(ModelDIVA):
def hyper_update(self, epoch, fun_scheduler):
"""hyper_update.
:param epoch:
:param fun_scheduler:
"""
dict_rst = fun_scheduler(epoch)
self.beta_d = dict_rst["beta_d"]
self.beta_y = dict_rst["beta_y"]
self.beta_x = dict_rst["beta_x"]
self.gamma_d = dict_rst["gamma_d"]

def hyper_init(self, functor_scheduler, trainer=None):
"""
initiate a scheduler object via class name and things inside this model
:param functor_scheduler: the class name of the scheduler
"""
return functor_scheduler(
trainer=trainer,
beta_d=self.beta_d,
beta_y=self.beta_y,
beta_x=self.beta_x,
gamma_d=self.gamma_d,
)

class ModelDIVADefault(ModelDIVA):
"""
default DIVA variant, tunable penalties identical to the base ModelDIVA
"""
if str_mu == "gammad_recon":
return ModelDIVAGammadRecon
if str_mu == "gammad":
return ModelDIVAGammad
if str_mu == "default":
return ModelDIVADefault
raise RuntimeError("unsupported value for str_mu; allowed candidates: default, gammad_recon, gammad")
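
The factory thus returns one of three nested classes, which differ only in which penalty multipliers the feedback scheduler may tune. A hedged usage sketch, assuming the package is importable:

```python
from domainlab.models.model_diva import mk_diva

# str_mu selects which penalty multipliers the feedback scheduler controls
for str_mu, cls_name in [("default", "ModelDIVADefault"),
                         ("gammad", "ModelDIVAGammad"),
                         ("gammad_recon", "ModelDIVAGammadRecon")]:
    model_class = mk_diva(str_mu=str_mu)
    assert model_class.__name__ == cls_name
```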
63 changes: 63 additions & 0 deletions domainlab/utils/generate_fbopt_phase_portrait.py
@@ -0,0 +1,63 @@
import glob
import os

import matplotlib.pyplot as plt
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator


# FIXME: maybe adjust the output path where the png is saved
output_dir = "../.."

def get_xy_from_event_file(event_file, tf_size_guidance=None):
if tf_size_guidance is None:
# settings for which/how much data is loaded from the tensorboard event files
tf_size_guidance = {
'compressedHistograms': 0,
'images': 0,
'scalars': 1e10, # keep unlimited number
'histograms': 0
}
# load event file
event = EventAccumulator(event_file, tf_size_guidance)
event.Reload()
# extract the reg/dyn0 values
y_event = event.Scalars('x-axis=task vs y-axis=reg/dyn0')
y = [s.value for s in y_event]
x_int = [s.step for s in y_event]  # the .step data are saved as ints in tensorboard, so we will re-extract x from 'task'
# extract the corresponding 'task' values
x_event = event.Scalars('task')
x = [s.value for s in x_event]
# sanity check:
for i in range(len(x)):
assert int(x[i]) == x_int[i]
return x, y

def phase_portrait_combined(event_files, colors):
plt.figure()

for event_i in range(len(event_files)):
x, y = get_xy_from_event_file(event_files[event_i])

assert len(x) == len(y)
for i in range(len(x)-1):
plt.arrow(x[i], y[i], (x[i+1]-x[i]), (y[i+1]-y[i]),
head_width=0.2, head_length=0.2, length_includes_head=True,
fc=colors[event_i], ec=colors[event_i], alpha=0.4)

plt.plot(x[0], y[0], 'ko')
plt.scatter(x, y, s=1, c='black')

plt.xlabel("task")
plt.ylabel("reg/dyn0")
plt.title("x-axis=task vs y-axis=reg/dyn0")

plt.savefig(os.path.join(output_dir, 'phase_portrait_combined.png'), dpi=300)


if __name__ == "__main__":
event_files = glob.glob("../../runs/*/events*")
print("Using the following tensorboard event files:\n{}".format("\n".join(event_files)))
cmap = plt.get_cmap('tab10') # Choose a colormap
colors = [cmap(i) for i in range(len(event_files))] # Different colors for the different runs
phase_portrait_combined(event_files, colors)
