Commit

Merge branch 'master' into xd_loader_te_none
smilesun authored Sep 14, 2023
2 parents 22162a0 + d1acb38 commit 4af4375
Showing 7 changed files with 177 additions and 38 deletions.
80 changes: 80 additions & 0 deletions docs/docFishr.md
@@ -0,0 +1,80 @@
# Fishr: Invariant Gradient Variances for Out-of-distribution Generalization

The goal of the Fishr regularization technique is to locally align the domain-level loss landscapes
around the final weights, i.e. to find a minimizer around which the inconsistencies between
the domain-level loss landscapes are as small as possible.
This is achieved by considering second-order terms during training, namely by matching
the variances of the domain-level gradients.

<div style="align: center; text-align:center;">
<img src="figs/fishr.png" style="width:450px;"/>
<div class="caption">Figure 1: Fishr matches the domain-level gradient variances of the
distributions across the training domains (Image source: Figure 1 of "Fishr:
Invariant gradient variances for out-of-distribution generalization") </div>
</div>



### Quantifying inconsistency between domains
Intuitively, two domains are locally inconsistent around a minimizer if a small
perturbation of the minimizer strongly affects its optimality in one domain, but only
minimally affects its optimality in the other domain. Under certain assumptions, most importantly
that the Hessians are positive definite, the inconsistency between two domains
$$A$$ and $$B$$ can be measured with the following inconsistency score:

$$
\mathcal{I}^\epsilon (\theta^*) = \max_{(A,B)\in\mathcal{E}^2} \biggl( \mathcal{R}_B(\theta^*) - \mathcal{R}_A(\theta^*) + \max_{\frac{1}{2}\theta^T H_A \theta\leq\epsilon}\frac{1}{2}\theta^T H_B \theta \biggr)
$$

whereby $$\theta^*$$ denotes the minimizer, $$\mathcal{E}$$ denotes the set of training domains,
$$H_e$$ denotes the Hessian for domain $$e\in\mathcal{E}$$, $$\theta$$ denotes the network parameters
and $$\mathcal{R}_e$$ denotes the domain-level ERM objective for $$e\in\mathcal{E}$$.
The Fishr regularization method forces both terms on the right-hand side
of the inconsistency score to become small. The first term, the difference
between the domain-level risks, is implicitly forced to be small by applying
the Fishr regularization. For the second term, it suffices to align diagonal approximations of the
domain-level Hessians, matching the gradient variances across domains.
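
To build intuition, the following minimal sketch (hypothetical helper names, not part of DomainLab) evaluates the inconsistency score for the special case of diagonal, positive definite Hessian approximations, where the inner constrained maximization has the closed form $$\epsilon \cdot \max_i (H_B)_{ii} / (H_A)_{ii}$$:

```python
import numpy as np

def pairwise_inconsistency(risk_a, risk_b, hess_a, hess_b, eps=1.0):
    # closed form of: max 0.5 * t^T H_B t  subject to  0.5 * t^T H_A t <= eps,
    # valid for diagonal positive definite H_A, H_B: eps * max_i(H_B[i] / H_A[i])
    return risk_b - risk_a + eps * np.max(hess_b / hess_a)

def inconsistency_score(risks, hess_diags, eps=1.0):
    # maximum over all ordered domain pairs (A, B)
    return max(
        pairwise_inconsistency(risks[a], risks[b], hess_diags[a], hess_diags[b], eps)
        for a in risks for b in risks
    )

# toy example with two domains and a two-parameter model
risks = {"A": 0.30, "B": 0.35}
hess_diags = {"A": np.array([2.0, 1.0]), "B": np.array([1.0, 3.0])}
print(inconsistency_score(risks, hess_diags, eps=0.5))
```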




### Matching the variances during training
Let $$\mathcal{E}$$ be the set of all training domains, and let $$\mathcal{R}_e(\theta)$$ be the ERM
objective of domain $$e\in\mathcal{E}$$. Fishr minimizes the following objective function during training:

$$
\mathcal{L}(\theta) = \frac{1}{|\mathcal{E}|}\sum_{e\in\mathcal{E}} \mathcal{R}_e(\theta) + \lambda \mathcal{L}_{\textnormal{Fishr}}(\theta)
$$

whereby

$$
\mathcal{L}_\textnormal{Fishr}(\theta) = \frac{1}{|\mathcal{E}|}\sum_{e\in\mathcal{E}} \| v_e -v \|^2_2
$$

with $$v_e$$ denoting the variance of the gradients within domain $$e\in\mathcal{E}$$ and
$$v$$ denoting the average gradient variance across all training domains, i.e.
$$v = \frac{1}{|\mathcal{E}|}\sum_{e\in\mathcal{E}} v_e$$.
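
For illustration, here is a minimal sketch of the Fishr penalty in PyTorch, assuming the per-domain gradient variance vectors $$v_e$$ have already been computed and flattened; the function name and signature are illustrative, not the DomainLab API:

```python
import torch

def fishr_penalty(grad_variances):
    # grad_variances: list of 1-D tensors, one per training domain,
    # each holding the per-parameter gradient variances v_e
    v_mean = torch.stack(grad_variances).mean(dim=0)  # average variance v
    # mean over domains of || v_e - v ||_2^2
    return torch.stack(
        [(v_e - v_mean).pow(2).sum() for v_e in grad_variances]
    ).mean()

# toy usage with three domains and a five-parameter model
penalty = fishr_penalty([torch.rand(5) for _ in range(3)])
```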





### Implementation
The variance of the gradients within each domain can be computed with the
BackPACK package (see: Dangel, Felix, Frederik Kunstner, and Philipp Hennig.
"BackPACK: Packing more into backprop." https://arxiv.org/abs/1912.10985).
Furthermore, we use the approximation $$ \textnormal{Var}(G) \approx \textnormal{diag}(H) $$,
i.e. the gradient variances approximate the diagonal of the Hessian.
The Hessian is then approximated by the Fisher information matrix, which
in turn is approximated by an empirical estimator for computational efficiency.
For more details, see the reference below or the DomainLab code.
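
As a minimal sketch (toy model and data; the actual DomainLab implementation may differ), the gradient variances for one domain's minibatch can be obtained with BackPACK's `Variance` extension:

```python
import torch
from backpack import backpack, extend
from backpack.extensions import Variance

# toy model and one domain's minibatch (shapes are illustrative)
model = extend(torch.nn.Linear(10, 2))
loss_func = extend(torch.nn.CrossEntropyLoss())
x, y = torch.randn(32, 10), torch.randint(0, 2, (32,))

loss = loss_func(model(x), y)
with backpack(Variance()):
    loss.backward()  # populates param.variance for each parameter

# flat vector v_e of gradient variances for this domain
v_e = torch.cat([p.variance.flatten() for p in model.parameters()])
```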






_Reference:_
Rame, Alexandre, Corentin Dancette, and Matthieu Cord. "Fishr:
Invariant gradient variances for out-of-distribution generalization."
International Conference on Machine Learning. PMLR, 2022.
Binary file added docs/figs/fishr.png
2 changes: 1 addition & 1 deletion domainlab/algos/msels/c_msel_val.py
@@ -30,7 +30,7 @@ def update(self):
else:
self.es_c += 1
logger = Logger.get_logger()
logger.debug("early stop counter: ", self.es_c)
logger.debug(f"early stop counter: {self.es_c}")
logger.debug(f"val acc:{self.tr_obs.metric_te['acc']}, "
f"best validation acc: {self.best_val_acc}")
flag = False # do not update best model
2 changes: 1 addition & 1 deletion domainlab/arg_parser.py
@@ -178,7 +178,7 @@ def mk_parser_main():
arg_group_task.add_argument('--san_num', type=int, default=8,
help='number of images to be dumped for the sanity check')

arg_group_task.add_argument('--loglevel', type=str, default='INFO',
arg_group_task.add_argument('--loglevel', type=str, default='DEBUG',
help='sets the loglevel of the logger')

# args for variational auto encoder
85 changes: 55 additions & 30 deletions domainlab/exp_protocol/benchmark.smk
@@ -1,6 +1,8 @@
import os
import sys
from pathlib import Path
import pandas as pd


try:
config_path = workflow.configfiles[0]
@@ -24,20 +26,38 @@ def experiment_result_files(_):
"""Lists all expected i.csv"""
from domainlab.utils.hyperparameter_sampling import is_dict_with_key
from domainlab.utils.logger import Logger
# count tasks
num_sample_tasks = 0
num_nonsample_tasks = 0
for key, val in config.items():
if is_dict_with_key(val, "aname"):
if 'hyperparameters' in val.keys():
num_sample_tasks += 1
else:
num_nonsample_tasks += 1
# total number of hyperparameter samples
total_num_params = config['num_param_samples'] * num_sample_tasks + num_nonsample_tasks
from domainlab.utils.hyperparameter_gridsearch import \
sample_gridsearch

logger = Logger.get_logger()
logger.info(f"total_num_params={total_num_params}")
logger.info(f"={config['num_param_samples']} * {num_sample_tasks} + {num_nonsample_tasks}")
if config['mode'] == 'grid':
# hyperparameters are sampled using gridsearch
# in this case we don't know how many samples we will get beforehand
# straightforward solution: do a grid sampling and count the samples
samples = sample_gridsearch(config)
total_num_params = samples.shape[0]
logger.info(f"total_num_params={total_num_params} for gridsearch")
else:
# in case of random sampling it is possible to compute the number
# of samples from the information in the yaml file

# count tasks
num_sample_tasks = 0
num_nonsample_tasks = 0
for key, val in config.items():
if is_dict_with_key(val, "aname"):
if 'hyperparameters' in val.keys():
num_sample_tasks += 1
else:
if 'shared' in val.keys():
num_sample_tasks += 1
else:
num_nonsample_tasks += 1
# total number of hyperparameter samples
total_num_params = config['num_param_samples'] * num_sample_tasks + num_nonsample_tasks
logger.info(f"total_num_params={total_num_params} for random sampling")
logger.info(f"={config['num_param_samples']} * {num_sample_tasks} + {num_nonsample_tasks}")

return [f"{config['output_dir']}/rule_results/{i}.csv" for i in range(total_num_params)]


@@ -54,25 +74,29 @@ rule parameter_sampling:
from domainlab.utils.hyperparameter_sampling import sample_hyperparameters
from domainlab.utils.hyperparameter_gridsearch import sample_gridsearch

sampling_seed_str = params.sampling_seed
if isinstance(sampling_seed_str, str) and (len(sampling_seed_str) > 0):
# hash will keep an integer intact and hash strings to a random seed
# the hashed integer is signed and usually too big, while a random seed
# is only allowed to be in [0, 2^32-1]
# if the user input is a number, then hash will not change the value,
# so we recommend the user to use a number as the start seed
if sampling_seed_str.isdigit():
sampling_seed = int(sampling_seed_str)
else:
sampling_seed = abs(hash(sampling_seed_str)) % (2 ** 32)
elif 'sampling_seed' in config.keys():
sampling_seed = config['sampling_seed']
else:
sampling_seed = None
# for gridsearch there is no random component, therefore no
# random seed is needed
if 'mode' in config.keys(): # type(config)=dict
if config['mode'] == 'grid':
sample_gridsearch(config,str(output.dest),sampling_seed)
sample_gridsearch(config,str(output.dest))
# for random sampling we need to consider a random seed
else:
sampling_seed_str = params.sampling_seed
if isinstance(sampling_seed_str, str) and (len(sampling_seed_str) > 0):
# hash will keep an integer intact and hash strings to a random seed
# the hashed integer is signed and usually too big, while a random seed
# is only allowed to be in [0, 2^32-1]
# if the user input is a number, then hash will not change the value,
# so we recommend the user to use a number as the start seed
if sampling_seed_str.isdigit():
sampling_seed = int(sampling_seed_str)
else:
sampling_seed = abs(hash(sampling_seed_str)) % (2 ** 32)
elif 'sampling_seed' in config.keys():
sampling_seed = config['sampling_seed']
else:
sampling_seed = None

sample_hyperparameters(config, str(output.dest), sampling_seed)


@@ -123,7 +147,8 @@ rule run_experiment:
# in the resulting pandas dataframe
# :param out_file: path to the output csv
num_gpus = int(num_gpus_str)
run_experiment(config,str(input.param_file),index,str(output.out_file), start_seed, num_gpus=num_gpus)
run_experiment(config, str(input.param_file), index,str(output.out_file),
start_seed, num_gpus=num_gpus)


rule agg_results:
40 changes: 39 additions & 1 deletion domainlab/models/model_diva.py
@@ -10,8 +10,46 @@

def mk_diva(parent_class=VAEXYDClassif):
"""
DIVA with arbitrary task loss
Instantiate a domain invariant variational autoencoder (DIVA) with arbitrary task loss.
Details:
This method creates a generative model based on a variational autoencoder, which can
reconstruct the input images. To this end, three different encoders with latent variables are
trained, each representing a latent subspace for the domain, class and residual feature
information, respectively. The latent subspaces serve to disentangle the respective
sources of variation. To reconstruct the input image, the three latent variables are fed
into a decoder.
Additionally, two classifiers are trained, which predict the domain and the class label.
For more details, see:
Ilse, Maximilian, et al. "Diva: Domain invariant variational autoencoders."
Medical Imaging with Deep Learning. PMLR, 2020.
Args:
parent_class: Class object determining the task type. Defaults to VAEXYDClassif.
Returns:
ModelDIVA: model inheriting from parent class.
Input Parameters:
zd_dim: size of latent space for domain-specific information,
zy_dim: size of latent space for class-specific information,
zx_dim: size of latent space for residual variance,
chain_node_builder: creates the neural network specified by the user; object of the class
"VAEChainNodeGetter" (see domainlab/compos/vae/utils_request_chain_builder.py)
being initialized by entering a user request,
list_str_y: list of labels,
list_d_tr: list of training domains,
gamma_d: weighting term for d classifier,
gamma_y: weighting term for y classifier,
beta_d: weighting term for domain encoder,
beta_x: weighting term for residual variation encoder,
beta_y: weighting term for class encoder
Usage:
For a concrete example, see:
https://github.com/marrlab/DomainLab/blob/master/tests/test_mk_exp_diva.py
"""

class ModelDIVA(parent_class):
"""
DIVA
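
Based on the docstring above, a hypothetical usage sketch might look as follows; the argument names mirror the listed input parameters, but the exact signature may differ, so see tests/test_mk_exp_diva.py for a verified example:

```python
from domainlab.models.model_diva import mk_diva

ModelDIVA = mk_diva()  # defaults to the VAEXYDClassif task type

# all argument values below are illustrative placeholders
model = ModelDIVA(
    chain_node_builder=chain_node_builder,  # a VAEChainNodeGetter object
    zd_dim=32, zy_dim=32, zx_dim=32,
    list_str_y=["class_0", "class_1"],
    list_d_tr=["domain_1", "domain_2"],
    gamma_d=1e5, gamma_y=1e5,
    beta_d=1.0, beta_x=1.0, beta_y=1.0,
)
```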
6 changes: 1 addition & 5 deletions domainlab/utils/hyperparameter_gridsearch.py
@@ -242,8 +242,7 @@ def grid_task(grid_df: pd.DataFrame, task_name: str, config: dict):


def sample_gridsearch(config: dict,
dest: str = None,
sampling_seed: int = None) -> pd.DataFrame:
dest: str = None) -> pd.DataFrame:
"""
create the hyperparameters grid according to the given
config, which should be the dictionary of the full
@@ -257,9 +256,6 @@ def sample_gridsearch(config: dict,
if dest is None:
dest = config['output_dir'] + os.sep + 'hyperparameters.csv'

if not sampling_seed is None:
np.random.seed(sampling_seed)

logger = Logger.get_logger()
samples = pd.DataFrame(columns=['task', 'algo', 'params'])
for key, val in config.items():
