diff --git a/docs/docFishr.md b/docs/docFishr.md new file mode 100644 index 000000000..3d88cef07 --- /dev/null +++ b/docs/docFishr.md @@ -0,0 +1,80 @@ +# Fishr: Invariant Gradient Variances for Out-of-distribution Generalization + +The goal of the Fishr regularization technique is locally aligning the domain-level loss landscapes +around the final weights, finding a minimizer around which the inconsistencies between +the domain-level loss landscapes are as small as possible. +This is done by considering second order terms during training, matching +the variances between the domain-level gradients. + +
+ +
Figure 1: Fishr matches the domain-level gradient variances of the +distributions across the training domains (Image source: Figure 1 of "Fishr: +Invariant gradient variances for out-of-distribution generalization")
+
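To make the idea above concrete before the formal definitions, the sketch below evaluates the quantity Fishr matches: the per-domain variance of per-sample gradients, and the penalty comparing these variances across domains (defined precisely in the sections below). This is only an illustration, not DomainLab's Fishr trainer: the helper names and the toy data are invented for this example, the per-sample loop is deliberately naive, and the Fishr paper additionally restricts the gradients to the final classifier, uses efficient per-sample gradient computation, and tracks the variances with an exponential moving average.

```python
import torch

def per_sample_grad_variance(model, loss_fn, x, y):
    """Diagonal variance of per-sample gradients of the model parameters (v_e)."""
    grads = []
    for xi, yi in zip(x, y):
        model.zero_grad()
        loss = loss_fn(model(xi.unsqueeze(0)), yi.unsqueeze(0))
        loss.backward()
        # concatenate the gradient of every parameter into one flat vector
        grads.append(torch.cat([p.grad.flatten() for p in model.parameters()]))
    return torch.stack(grads).var(dim=0)  # one variance entry per parameter


def fishr_penalty(variances):
    """L_Fishr = (1/|E|) * sum_e ||v_e - v||_2^2, with v the mean of the v_e."""
    v_mean = torch.stack(variances).mean(dim=0)
    return sum(((v_e - v_mean) ** 2).sum() for v_e in variances) / len(variances)


# toy usage: two training domains and a small linear classifier
torch.manual_seed(0)
net = torch.nn.Linear(5, 3)
ce = torch.nn.CrossEntropyLoss()
domains = [(torch.randn(16, 5), torch.randint(0, 3, (16,))) for _ in range(2)]
v_per_domain = [per_sample_grad_variance(net, ce, x, y) for x, y in domains]
penalty = fishr_penalty(v_per_domain)  # scaled by lambda and added to the averaged ERM loss
```

Note that this sketch only evaluates the penalty for inspection; using it as a training objective requires a differentiable and efficient variance estimate, which is the part the actual Fishr implementation provides.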
+ + + +### Quantifying inconsistency between domains +Intuitively, two domains are locally inconsistent around a minimizer if a small +perturbation of the minimizer strongly affects its optimality in one domain, but only +minimally affects its optimality in the other domain. Under certain assumptions, most importantly +the Hessians being positive definite, it is possible to measure the inconsistency between two domains +$$A$$ and $$B$$ with the following inconsistency score: + +$$ +\mathcal{I}^\epsilon (\theta^*) = \max_{(A,B)\in\mathcal{E}^2} \biggl( \mathcal{R}_B(\theta^*) - \mathcal{R}_A(\theta^*) + \max_{\frac{1}{2}\theta^T H_A \theta\leq\epsilon}\frac{1}{2}\theta^T H_B \theta \biggr) +$$ + +where $$\theta^*$$ denotes the minimizer, $$\mathcal{E}$$ denotes the set of training domains, +$$H_e$$ denotes the Hessian for $$e\in\mathcal{E}$$, $$\theta$$ denotes the network parameters +and $$\mathcal{R}_e$$ for $$e\in\mathcal{E}$$ denotes the domain-level ERM objective. +The Fishr regularization method forces both terms on the right-hand side +of the inconsistency score to become small. The first term represents the difference +between the domain-level risks and is implicitly forced to be small by applying +the Fishr regularization. For the second term it suffices to align diagonal approximations of the +domain-level Hessians, which amounts to matching the gradient variances across domains. + + + + +### Matching the Variances during training +Let $$\mathcal{E}$$ be the set of all training domains, and let $$\mathcal{R}_e(\theta)$$ be the ERM +objective of domain $$e$$. Fishr minimizes the following objective function during training: + +$$ +\mathcal{L}(\theta) = \frac{1}{|\mathcal{E}|}\sum_{e\in\mathcal{E}} \mathcal{R}_e(\theta) + \lambda \mathcal{L}_{\textnormal{Fishr}}(\theta) +$$ + +where + +$$ +\mathcal{L}_\textnormal{Fishr}(\theta) = \frac{1}{|\mathcal{E}|}\sum_{e\in\mathcal{E}} \| v_e - v \|^2_2 +$$ + +with $$v_e$$ denoting the variance of the gradients within domain $$e\in\mathcal{E}$$ and +$$v$$ denoting the average of the gradient variances across all domains, i.e.
+$$ 0): - # hash will keep integer intact and hash strings to random seed - # hased integer is signed and usually too big, random seed only - # allowed to be in [0, 2^32-1] - # if the user input is number, then hash will not change the value, - # so we recommend the user to use number as start seed - if sampling_seed_str.isdigit(): - sampling_seed = int(sampling_seed_str) - else: - sampling_seed = abs(hash(sampling_seed_str)) % (2 ** 32) - elif 'sampling_seed' in config.keys(): - sampling_seed = config['sampling_seed'] - else: - sampling_seed = None + # for gridsearch there is no random component, therefore no + # random seed is needed if 'mode' in config.keys(): # type(config)=dict if config['mode'] == 'grid': - sample_gridsearch(config,str(output.dest),sampling_seed) + sample_gridsearch(config,str(output.dest)) + # for random sampling we need to consider a random seed else: + sampling_seed_str = params.sampling_seed + if isinstance(sampling_seed_str, str) and (len(sampling_seed_str) > 0): + # hash will keep integer intact and hash strings to random seed + # hased integer is signed and usually too big, random seed only + # allowed to be in [0, 2^32-1] + # if the user input is number, then hash will not change the value, + # so we recommend the user to use number as start seed + if sampling_seed_str.isdigit(): + sampling_seed = int(sampling_seed_str) + else: + sampling_seed = abs(hash(sampling_seed_str)) % (2 ** 32) + elif 'sampling_seed' in config.keys(): + sampling_seed = config['sampling_seed'] + else: + sampling_seed = None + sample_hyperparameters(config, str(output.dest), sampling_seed) @@ -123,7 +147,8 @@ rule run_experiment: # in the resulting pandas dataframe # :param out_file: path to the output csv num_gpus = int(num_gpus_str) - run_experiment(config,str(input.param_file),index,str(output.out_file), start_seed, num_gpus=num_gpus) + run_experiment(config, str(input.param_file), index,str(output.out_file), + start_seed, num_gpus=num_gpus) rule agg_results: diff --git a/domainlab/models/model_diva.py b/domainlab/models/model_diva.py index 635d06839..703011f81 100644 --- a/domainlab/models/model_diva.py +++ b/domainlab/models/model_diva.py @@ -10,8 +10,46 @@ def mk_diva(parent_class=VAEXYDClassif): """ - DIVA with arbitrary task loss + Instantiate a domain invariant variational autoencoder (DIVA) with arbitrary task loss. + + Details: + This method is creating a generative model based on a variational autoencoder, which can + reconstruct the input images. Here for, three different encoders with latent variables are + trained, each representing a latent subspace for the domain, class and residual features + information, respectively. The latent subspaces serve for disentangling the respective + sources of variation. To reconstruct the input image, the three latent variables are fed + into a decoder. + Additionally, two classifiers are trained, which predict the domain and the class label. + For more details, see: + Ilse, Maximilian, et al. "Diva: Domain invariant variational autoencoders." + Medical Imaging with Deep Learning. PMLR, 2020. + + Args: + parent_class: Class object determining the task type. Defaults to VAEXYDClassif. + + Returns: + ModelDIVA: model inheriting from parent class. 
+ + Input Parameters: + zd_dim: size of latent space for domain-specific information, + zy_dim: size of latent space for class-specific information, + zx_dim: size of latent space for residual variation, + chain_node_builder: creates the neural network specified by the user; object of the class + "VAEChainNodeGetter" (see domainlab/compos/vae/utils_request_chain_builder.py), + which is initialized with a user request, + list_str_y: list of labels, + list_d_tr: list of training domains, + gamma_d: weighting term for the domain classifier, + gamma_y: weighting term for the class label classifier, + beta_d: weighting term for the domain encoder, + beta_x: weighting term for the residual variation encoder, + beta_y: weighting term for the class encoder + + Usage: + For a concrete example, see: + https://github.com/marrlab/DomainLab/blob/master/tests/test_mk_exp_diva.py """ + class ModelDIVA(parent_class): """ DIVA diff --git a/domainlab/utils/hyperparameter_gridsearch.py b/domainlab/utils/hyperparameter_gridsearch.py index ce05f7a40..aecd7a5fd 100644 --- a/domainlab/utils/hyperparameter_gridsearch.py +++ b/domainlab/utils/hyperparameter_gridsearch.py @@ -242,8 +242,7 @@ def grid_task(grid_df: pd.DataFrame, task_name: str, config: dict): def sample_gridsearch(config: dict, - dest: str = None, - sampling_seed: int = None) -> pd.DataFrame: + dest: str = None) -> pd.DataFrame: """ create the hyperparameters grid according to the given config, which should be the dictionary of the full @@ -257,9 +256,6 @@ def sample_gridsearch(config: dict, if dest is None: dest = config['output_dir'] + os.sep + 'hyperparameters.csv' - if not sampling_seed is None: - np.random.seed(sampling_seed) - logger = Logger.get_logger() samples = pd.DataFrame(columns=['task', 'algo', 'params']) for key, val in config.items():
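As a side note on the seed handling introduced in the benchmark Snakemake rule earlier in the diff, the hashing scheme can be reproduced in isolation as below. The function name `str_to_sampling_seed` is invented for this illustration; the body mirrors the added lines of the patch. One caveat worth keeping in mind is that Python salts `hash()` for strings per interpreter run unless `PYTHONHASHSEED` is fixed, which is a further reason to follow the comment's recommendation of passing a plain number as the seed.

```python
def str_to_sampling_seed(sampling_seed_str: str) -> int:
    """Map a user-supplied seed string to a value accepted by np.random.seed.

    Digit-only strings are kept as-is; any other string is hashed and folded
    into the allowed range [0, 2**32 - 1].
    """
    if sampling_seed_str.isdigit():
        return int(sampling_seed_str)
    return abs(hash(sampling_seed_str)) % (2 ** 32)


print(str_to_sampling_seed("42"))      # -> 42, unchanged
print(str_to_sampling_seed("my-run"))  # -> some value in [0, 2**32 - 1]
```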