diff --git a/docs/docFishr.md b/docs/docFishr.md
new file mode 100644
index 000000000..3d88cef07
--- /dev/null
+++ b/docs/docFishr.md
@@ -0,0 +1,80 @@
+# Fishr: Invariant Gradient Variances for Out-of-distribution Generalization
+
+The goal of the Fishr regularization technique is to locally align the domain-level loss landscapes
+around the final weights, i.e. to find a minimizer around which the inconsistencies between
+the domain-level loss landscapes are as small as possible.
+This is done by considering second-order terms during training, namely by matching
+the variances of the domain-level gradients across the training domains.
+
+
+
+
+Figure 1: Fishr matches the domain-level gradient variances of the
+distributions across the training domains (Image source: Figure 1 of "Fishr:
+Invariant gradient variances for out-of-distribution generalization")
+
+
+
+
+### Quantifying inconsistency between domains
+Intuitively, two domains are locally inconsistent around a minimizer if a small
+perturbation of the minimizer strongly affects its optimality in one domain, but only
+minimally affects its optimality in the other domain. Under certain assumptions, most importantly
+that the Hessians are positive definite, the inconsistency across the training domains can be measured
+with the following inconsistency score, which takes the maximum over all ordered pairs of domains $$(A, B)$$:
+
+$$
+\mathcal{I}^\epsilon (\theta^*) = \max_{(A,B)\in\mathcal{E}^2} \biggl( \mathcal{R}_B(\theta^*) - \mathcal{R}_A(\theta^*) + \max_{\frac{1}{2}\theta^T H_A \theta\leq\epsilon} \frac{1}{2}\theta^T H_B \theta \biggr)
+$$
+
+where $$\theta^*$$ denotes the minimizer, $$\mathcal{E}$$ denotes the set of training domains,
+$$H_e$$ denotes the Hessian for $$e\in\mathcal{E}$$, $$\theta$$ denotes the network parameters
+and $$\mathcal{R}_e$$ for $$e\in\mathcal{E}$$ denotes the domain-level ERM objective.
+The Fishr regularization method forces both terms on the right hand side
+of the inconsistency score to become small. The first term represents the difference
+between the domain-level risks and is implicitly forced to be small by applying
+the Fishr regularization. For the second term it suffices to align diagonal approximations of the
+domain-level Hessians, matching the variances across domains.
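+
+To make the score concrete, the following is a minimal numerical sketch (not part of the Fishr
+implementation; all names are illustrative) for diagonal, positive-definite Hessian approximations.
+In that case the inner maximization has a closed form: under the constraint
+$$\frac{1}{2}\theta^T H_A \theta\leq\epsilon$$, the maximum of $$\frac{1}{2}\theta^T H_B \theta$$ is
+$$\epsilon \cdot \max_i (H_B)_{ii} / (H_A)_{ii}$$.
+
+```python
+import numpy as np
+
+def inconsistency_score(risks, hessian_diags, epsilon=1.0):
+    """Toy inconsistency score for diagonal positive-definite Hessian approximations.
+
+    risks: dict mapping domain name -> empirical risk R_e(theta*)
+    hessian_diags: dict mapping domain name -> 1d array of positive Hessian diagonal entries
+    """
+    score = -np.inf
+    for dom_a in risks:
+        for dom_b in risks:
+            # closed-form inner maximum of 1/2 theta^T H_B theta
+            # subject to 1/2 theta^T H_A theta <= epsilon
+            inner_max = epsilon * np.max(hessian_diags[dom_b] / hessian_diags[dom_a])
+            score = max(score, risks[dom_b] - risks[dom_a] + inner_max)
+    return score
+
+# toy example: two domains, three parameters
+risks = {"A": 0.20, "B": 0.35}
+hessians = {"A": np.array([1.0, 2.0, 0.5]), "B": np.array([1.5, 1.0, 2.0])}
+print(inconsistency_score(risks, hessians, epsilon=0.1))
+```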
+
+
+
+
+### Matching the variances during training
+Let $$\mathcal{E}$$ be the set of all training domains, and let $$\mathcal{R}_e(\theta)$$ be the ERM
+objective of domain $$e$$. Fishr minimizes the following objective function during training:
+
+$$
+\mathcal{L}(\theta) = \frac{1}{|\mathcal{E}|}\sum_{e\in\mathcal{E}} \mathcal{R}_e(\theta) + \lambda \mathcal{L}_{\textnormal{Fishr}}(\theta)
+$$
+
+where
+
+$$
+\mathcal{L}_\textnormal{Fishr}(\theta) = \frac{1}{|\mathcal{E}|}\sum_{e\in\mathcal{E}} \| v_e -v \|^2_2
+$$
+
+with $$v_e$$ denoting the elementwise variance of the gradients within domain $$e\in\mathcal{E}$$ and
+$$v = \frac{1}{|\mathcal{E}|}\sum_{e\in\mathcal{E}} v_e$$ denoting the average of these gradient variances across all training domains.
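+
+The following is a rough PyTorch-style sketch of this penalty (illustrative only, not the
+DomainLab implementation; the reference implementation obtains per-sample gradients more
+efficiently, e.g. via the BackPACK package, instead of the naive loop used here):
+
+```python
+import torch
+
+def fishr_penalty(losses_per_domain, params):
+    """losses_per_domain: list with one 1d tensor of per-sample losses per training domain.
+    params: model parameters whose gradient variances are matched across domains."""
+    params = list(params)
+    variances = []
+    for losses in losses_per_domain:
+        # per-sample gradients, one row per sample (slow but simple loop)
+        grads = []
+        for loss_i in losses:
+            grad_i = torch.autograd.grad(loss_i, params, retain_graph=True, create_graph=True)
+            grads.append(torch.cat([g.flatten() for g in grad_i]))
+        grads = torch.stack(grads)           # shape (n_samples, n_params)
+        variances.append(grads.var(dim=0))   # v_e: elementwise gradient variance
+    v_mean = torch.stack(variances).mean(dim=0)   # v: average over training domains
+    # L_Fishr = mean over domains of || v_e - v ||_2^2
+    return torch.stack([((v_e - v_mean) ** 2).sum() for v_e in variances]).mean()
+```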
+$$ 0):
- # hash will keep integer intact and hash strings to random seed
- # hased integer is signed and usually too big, random seed only
- # allowed to be in [0, 2^32-1]
- # if the user input is number, then hash will not change the value,
- # so we recommend the user to use number as start seed
- if sampling_seed_str.isdigit():
- sampling_seed = int(sampling_seed_str)
- else:
- sampling_seed = abs(hash(sampling_seed_str)) % (2 ** 32)
- elif 'sampling_seed' in config.keys():
- sampling_seed = config['sampling_seed']
- else:
- sampling_seed = None
+ # for gridsearch there is no random component, therefore no
+ # random seed is needed
if 'mode' in config.keys(): # type(config)=dict
if config['mode'] == 'grid':
- sample_gridsearch(config,str(output.dest),sampling_seed)
+ sample_gridsearch(config,str(output.dest))
+ # for random sampling we need to consider a random seed
else:
+ sampling_seed_str = params.sampling_seed
+ if isinstance(sampling_seed_str, str) and (len(sampling_seed_str) > 0):
+                # hash() keeps integers intact and hashes strings to a random seed;
+                # the hashed integer is signed and usually too big, while a random seed
+                # is only allowed to be in [0, 2^32-1]
+                # if the user input is a number, hash() will not change the value,
+                # so we recommend using a number as the start seed
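+                # note: Python's built-in hash() is salted per process for strings
+                # (PYTHONHASHSEED), so a string seed is only reproducible within a single run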
+ if sampling_seed_str.isdigit():
+ sampling_seed = int(sampling_seed_str)
+ else:
+ sampling_seed = abs(hash(sampling_seed_str)) % (2 ** 32)
+ elif 'sampling_seed' in config.keys():
+ sampling_seed = config['sampling_seed']
+ else:
+ sampling_seed = None
+
sample_hyperparameters(config, str(output.dest), sampling_seed)
@@ -123,7 +147,8 @@ rule run_experiment:
# in the resulting pandas dataframe
# :param out_file: path to the output csv
num_gpus = int(num_gpus_str)
- run_experiment(config,str(input.param_file),index,str(output.out_file), start_seed, num_gpus=num_gpus)
+        run_experiment(config, str(input.param_file), index, str(output.out_file),
+ start_seed, num_gpus=num_gpus)
rule agg_results:
diff --git a/domainlab/models/model_diva.py b/domainlab/models/model_diva.py
index 635d06839..703011f81 100644
--- a/domainlab/models/model_diva.py
+++ b/domainlab/models/model_diva.py
@@ -10,8 +10,46 @@
def mk_diva(parent_class=VAEXYDClassif):
"""
- DIVA with arbitrary task loss
+ Instantiate a domain invariant variational autoencoder (DIVA) with arbitrary task loss.
+
+ Details:
+        This method creates a generative model based on a variational autoencoder, which can
+        reconstruct the input images. For this, three different encoders with latent variables are
+        trained, each representing a latent subspace for the domain, class and residual feature
+        information, respectively. The latent subspaces serve to disentangle the respective
+        sources of variation. To reconstruct the input image, the three latent variables are fed
+        into a decoder.
+        Additionally, two classifiers are trained, which predict the domain and the class label.
+ For more details, see:
+ Ilse, Maximilian, et al. "Diva: Domain invariant variational autoencoders."
+ Medical Imaging with Deep Learning. PMLR, 2020.
+
+ Args:
+ parent_class: Class object determining the task type. Defaults to VAEXYDClassif.
+
+ Returns:
+ ModelDIVA: model inheriting from parent class.
+
+ Input Parameters:
+ zd_dim: size of latent space for domain-specific information,
+ zy_dim: size of latent space for class-specific information,
+ zx_dim: size of latent space for residual variance,
+        chain_node_builder: creates the neural network specified by the user; an object of the
+            class "VAEChainNodeGetter" (see domainlab/compos/vae/utils_request_chain_builder.py),
+            initialized with a user request,
+        list_str_y: list of class labels,
+        list_d_tr: list of training domains,
+        gamma_d: weighting term for the domain classifier,
+        gamma_y: weighting term for the class classifier,
+ beta_d: weighting term for domain encoder,
+ beta_x: weighting term for residual variation encoder,
+ beta_y: weighting term for class encoder
+
+ Usage:
+ For a concrete example, see:
+ https://github.com/marrlab/DomainLab/blob/master/tests/test_mk_exp_diva.py
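+
+        A rough sketch of the expected call pattern (values are placeholders; argument names
+        follow the Input Parameters listed above, the linked test is authoritative):
+
+            model_class = mk_diva()  # returns the ModelDIVA class
+            model = model_class(
+                chain_node_builder=node_builder,  # a VAEChainNodeGetter instance (construction omitted)
+                zd_dim=8, zy_dim=8, zx_dim=8,
+                list_str_y=["class_a", "class_b"],
+                list_d_tr=["domain_1", "domain_2"],
+                gamma_d=1e5, gamma_y=1e5,
+                beta_d=1.0, beta_x=1.0, beta_y=1.0,
+            )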
"""
+
class ModelDIVA(parent_class):
"""
DIVA
diff --git a/domainlab/utils/hyperparameter_gridsearch.py b/domainlab/utils/hyperparameter_gridsearch.py
index ce05f7a40..aecd7a5fd 100644
--- a/domainlab/utils/hyperparameter_gridsearch.py
+++ b/domainlab/utils/hyperparameter_gridsearch.py
@@ -242,8 +242,7 @@ def grid_task(grid_df: pd.DataFrame, task_name: str, config: dict):
def sample_gridsearch(config: dict,
- dest: str = None,
- sampling_seed: int = None) -> pd.DataFrame:
+ dest: str = None) -> pd.DataFrame:
"""
create the hyperparameters grid according to the given
config, which should be the dictionary of the full
@@ -257,9 +256,6 @@ def sample_gridsearch(config: dict,
if dest is None:
dest = config['output_dir'] + os.sep + 'hyperparameters.csv'
- if not sampling_seed is None:
- np.random.seed(sampling_seed)
-
logger = Logger.get_logger()
samples = pd.DataFrame(columns=['task', 'algo', 'params'])
for key, val in config.items():