diff --git a/docs/doc_miro.md b/docs/doc_miro.md
new file mode 100644
index 000000000..89a8402ff
--- /dev/null
+++ b/docs/doc_miro.md
@@ -0,0 +1,70 @@
+# MIRO: Mutual-Information Regularization
+## Mutual Information Regularization with Oracle (MIRO)
+
+### Pre-requisite: Variational lower bound on mutual information
+
+Barber, David, and Felix Agakov. "The IM algorithm: a variational approach to information maximization." Advances in Neural Information Processing Systems 16 (2004).
+
+$$I(X,Y)=H(X)-H(X|Y)=H(X)+{\langle\log p(X|Y)\rangle}_{p(X,Y)}$$
+
+Given a variational distribution $q(X|Y)$ as decoder (i.e. $Y$ encodes information from $X$, and the decoder reconstructs $X$ from $Y$):
+
+Since
+
+$$KL\left(p(X|Y)\,\|\,q(X|Y)\right)={\langle\log p(X|Y)\rangle}_{p(X|Y)}-{\langle\log q(X|Y)\rangle}_{p(X|Y)} \geq 0$$
+
+we have
+
+$${\langle\log p(X|Y)\rangle}_{p(X,Y)}\geq{\langle\log q(X|Y)\rangle}_{p(X,Y)}$$
+
+Then
+
+$$I(X,Y)=H(X)+{\langle\log p(X|Y)\rangle}_{p(X,Y)}\geq H(X)+{\langle\log q(X|Y)\rangle}_{p(X,Y)}$$
+
+with the lower bound being
+
+$$H(X)+{\langle\log q(X|Y)\rangle}_{p(X,Y)}$$
+
+To optimize the lower bound, one can iterate:
+
+- fix the decoder $q(X|Y)$ and optimize the encoder $Y=g(X;\theta) + \epsilon$
+- fix the encoder parameter $\theta$ and tune the decoder to tighten the lower bound
+
+#### Laplace approximation
+
+decoding posterior, expanded around its mode $x^{*}(Y)$:
+
+$$p(X|Y) \approx \mathcal{N}\left(x^{*}(Y),\, \Sigma(Y)\right), \qquad [\Sigma^{-1}]_{ij}=-\left.\frac{\partial^2 \log p(x|y)}{\partial x_i\partial x_j}\right|_{x=x^{*}(Y)}$$
+
+This matters when $|Y|$ is large (a large deviation from zero carries more information, which must be explained by a non-typical $X$).
+
+#### Linear Gaussian
+
+With a Gaussian decoder $q(x|y)=\mathcal{N}\left(x \mid m(y), \Sigma(y)\right)$, the bound $H(X)+{\langle\log q(X|Y)\rangle}_{p(X,Y)}$ becomes, up to an additive constant,
+
+$$H(X)-\frac{1}{2}{\left\langle \|X-m(Y)\|^2_{\Sigma(Y)^{-1}} + \log\det\Sigma(Y)\right\rangle}_{p(X,Y)}$$
+
+where $\|v\|^2_{A}=v^{\top}A\,v$.
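+
+As a quick sanity check of this bound, here is a minimal Monte-Carlo sketch for a one-dimensional linear-Gaussian pair (illustrative only, not part of the package; all names are made up). It fits the decoder mean $m(y)$ by least squares and estimates ${\langle\log q(X|Y)\rangle}_{p(X,Y)}$ from samples; since $H(X)$ does not depend on the decoder, a larger estimate means a tighter lower bound on $I(X,Y)$.
+
+```python
+import torch
+
+torch.manual_seed(0)
+n = 10_000
+x = torch.randn(n, 1)                  # source X ~ N(0, 1)
+y = 2.0 * x + 0.1 * torch.randn(n, 1)  # noisy linear encoding Y = g(X) + eps
+
+# least-squares decoder mean m(y) = a * y, with q(x|y) = N(m(y), s2)
+a = (x * y).sum() / (y * y).sum()
+resid = x - a * y
+s2 = resid.pow(2).mean()
+
+# Monte-Carlo estimate of <log q(X|Y)> under p(X, Y)
+log_q = -0.5 * (resid.pow(2) / s2 + torch.log(2 * torch.pi * s2))
+print(log_q.mean())  # larger value => tighter lower bound on I(X, Y)
+```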
+## MIRO
+
+MIRO matches the pre-trained model's features, layer by layer, to those of the target neural network being trained, in terms of mutual information, to encourage domain invariance. The mean encoder applied to the target network's features is a constant identity map, and the variance is a population-level covariance $\Sigma$ (forced to be diagonal).
+
+Let $z$ denote the intermediate features of a layer, $f_0$ the pre-trained (oracle) model, $f$ the target neural network, and $x$ the input data:
+
+$$z_f=f(x)$$
+
+$$z_{f_0}=f_0(x)$$
+
+The resulting penalty (the negative of the variational lower bound on mutual information, up to constants) for instance $i$ is
+
+$$\log|\Sigma| + \|z^{(i)}_{f_0}-\mathrm{id}(z^{(i)}_f)\|^2_{\Sigma^{-1}}$$
+
+where $\mathrm{id}$ is the mean map.
+
+For diagonal $\Sigma$, the determinant is simply the product of the diagonal values, $\log|\Sigma|=\sum_k \log\sigma_k$, so the penalty decomposes per feature dimension $k$:
+
+$$\sum_{k} \left(\log \sigma_k + \frac{\left(z^{(i)}_{f_0,k}-z^{(i)}_{f,k}\right)^2}{\sigma_k}\right)$$
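+
+A minimal standalone sketch of this per-instance penalty (illustrative only; `miro_penalty`, `z_f`, `z_f0` and `sigma` are made-up names, and a single layer with flattened features is assumed):
+
+```python
+import torch
+
+def miro_penalty(z_f, z_f0, sigma):
+    """sum_k log(sigma_k) + (z_f0_k - z_f_k)^2 / sigma_k, per instance
+
+    z_f, z_f0: [batch, dim] features of target and pre-trained model
+    sigma:     [dim] diagonal of the population-level covariance
+    """
+    return (torch.log(sigma) + (z_f0 - z_f).pow(2) / sigma).sum(dim=1)
+
+penalty = miro_penalty(torch.randn(4, 8), torch.randn(4, 8),
+                       torch.full((8,), 0.1))  # shape: [4]
+```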
diff --git a/domainlab/algos/trainers/a_trainer.py b/domainlab/algos/trainers/a_trainer.py
index db5c70c6a..a484812d4 100644
--- a/domainlab/algos/trainers/a_trainer.py
+++ b/domainlab/algos/trainers/a_trainer.py
@@ -91,6 +91,8 @@ def __init__(self, successor_node=None, extend=None):
         self._ma_iter = 0
         #
         self.list_reg_over_task_ratio = None
+        # MIRO: input shape recorded from the training loader
+        self.input_tensor_shape = None
 
     @property
     def model(self):
@@ -203,6 +205,7 @@ def cal_reg_loss_over_task_loss_ratio(self):
         for ind_batch, (tensor_x, tensor_y, tensor_d, *others) in enumerate(
             self.loader_tr
         ):
+            self.input_tensor_shape = tensor_x.shape
             if ind_batch >= self.aconf.nb4reg_over_task_ratio:
                 return
             tensor_x, tensor_y, tensor_d = (
diff --git a/domainlab/algos/trainers/args_miro.py b/domainlab/algos/trainers/args_miro.py
new file mode 100644
index 000000000..188ac2db1
--- /dev/null
+++ b/domainlab/algos/trainers/args_miro.py
@@ -0,0 +1,17 @@
+"""
+miro trainer configurations
+"""
+
+
+def add_args2parser_miro(parser):
+    """
+    append hyper-parameters to the main argparser
+    """
+    arg_group_miro = parser.add_argument_group("miro")
+    arg_group_miro.add_argument(
+        "--layers2extract_feats",
+        nargs="*",
+        default=None,
+        help="space-separated layer names from which to extract features",
+    )
+    return parser
diff --git a/domainlab/algos/trainers/train_miro.py b/domainlab/algos/trainers/train_miro.py
new file mode 100644
index 000000000..657fac807
--- /dev/null
+++ b/domainlab/algos/trainers/train_miro.py
@@ -0,0 +1,63 @@
+"""
+MIRO: Mutual-Information Regularization with Oracle, from Kakao Brain
+https://arxiv.org/pdf/2203.10789#page=3.77
+[aut] xudong, alexej
+"""
+
+import torch
+from torch import nn
+from domainlab.algos.trainers.train_basic import TrainerBasic
+from domainlab.algos.trainers.train_miro_utils import \
+    MeanEncoder, VarianceEncoder
+from domainlab.algos.trainers.train_miro_model_wraper import \
+    TrainerMiroModelWraper
+
+
+class TrainerMiro(TrainerBasic):
+    """Mutual-Information Regularization with Oracle"""
+    def before_tr(self):
+        self.model_wraper = TrainerMiroModelWraper()
+        self.model_wraper.accept(self.model,
+                                 name_feat_layers2extract=self.aconf.layers2extract_feats)
+        self.mean_encoders = None
+        self.var_encoders = None
+        # super().before_tr() iterates the training loader (see
+        # cal_reg_loss_over_task_loss_ratio), which records
+        # self.input_tensor_shape used below to probe the feature shapes
+        super().before_tr()
+        shapes = self.model_wraper.get_shapes(self.input_tensor_shape)
+        self.mean_encoders = nn.ModuleList([
+            MeanEncoder(shape) for shape in shapes
+        ])
+        self.var_encoders = nn.ModuleList([
+            VarianceEncoder(shape) for shape in shapes
+        ])
+
+    def _cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others=None):
+        # list_batch_inter_feat_new holds one feature tensor per hooked layer
+        list_batch_inter_feat_new = \
+            self.model_wraper.extract_intermediate_features(
+                tensor_x, tensor_y, tensor_d, others)
+
+        # reference (oracle) model: frozen, no gradients needed
+        with torch.no_grad():
+            list_batch_inter_feat_ref = self.model_wraper.cal_feat_layers_ref_model(
+                tensor_x, tensor_y, tensor_d, others)
+        # dim(list_batch_inter_feat_ref)=[size_batch, dim_feat]
+        if self.mean_encoders is None:
+            device = tensor_x.device
+            return [torch.zeros(tensor_x.shape[0]).to(device)], [self.aconf.gamma_reg]
+
+        reg_loss = 0.
+        num_layers = len(self.mean_encoders)
+        device = tensor_x.device
+        for ind_layer in range(num_layers):
+            # layerwise mutual information regularization
+            mean_encoder = self.mean_encoders[ind_layer].to(device)
+            feat = list_batch_inter_feat_new[ind_layer]
+            feat = feat.to(device)
+            mean = mean_encoder(feat)
+            var_encoder = self.var_encoders[ind_layer].to(device)
+            var = var_encoder(feat)
+            mean_ref = list_batch_inter_feat_ref[ind_layer]
+            mean_ref = mean_ref.to(device)
+            # negative variational lower bound per element:
+            # (mean - mean_ref)^2 / var + log(var), cf. docs/doc_miro.md
+            vlb = (mean - mean_ref).pow(2).div(var) + var.log()
+            # average over feature dimensions, keep the per-instance dimension
+            reg_loss += vlb.mean(dim=tuple(range(1, vlb.dim()))) / 2.
+        return [reg_loss], [self.aconf.gamma_reg]
diff --git a/domainlab/algos/trainers/train_miro_model_wraper.py b/domainlab/algos/trainers/train_miro_model_wraper.py
new file mode 100644
index 000000000..2b0994b01
--- /dev/null
+++ b/domainlab/algos/trainers/train_miro_model_wraper.py
@@ -0,0 +1,72 @@
+"""
+hook-based wrapper extracting intermediate features from the guest model
+and from its frozen copy (the oracle), see
+https://arxiv.org/pdf/2203.10789#page=3.77
+"""
+import copy
+import torch
+
+
+class TrainerMiroModelWraper():
+    """Mutual-Information Regularization with Oracle"""
+    def __init__(self):
+        self._features = []
+        self._features_ref = []
+        self.guest_model = None
+        self.ref_model = None
+        self.flag_module_found = False
+
+    def get_shapes(self, input_shape):
+        """probe the shapes of the intermediate features with a dummy
+        forward pass; a single-instance batch is used so that the encoder
+        parameters broadcast over batches of any size"""
+        self.clear_features()
+        with torch.no_grad():
+            dummy = torch.rand(1, *input_shape[1:]).to(
+                next(self.guest_model.parameters()).device)
+            self.guest_model(dummy)
+        shapes = [feat.shape for feat in self._features]
+        return shapes
+
+    def accept(self, guest_model, name_feat_layers2extract=None):
+        self.guest_model = guest_model
+        # the frozen deep copy serves as the oracle/reference model
+        self.ref_model = copy.deepcopy(guest_model)
+        self.register_feature_storage_hook(name_feat_layers2extract)
+
+    def register_feature_storage_hook(self, feat_layers=None):
+        # memorize features for each layer in the self._features list
+        if feat_layers is None:
+            # default: hook the last child module of the model
+            module = list(self.guest_model.children())[-1]
+            module.register_forward_hook(self.hook)
+            module_ref = list(self.ref_model.children())[-1]
+            module_ref.register_forward_hook(self.hook_ref)
+        else:
+            for name, module in self.guest_model.named_modules():
+                if name in feat_layers:
+                    module.register_forward_hook(self.hook)
+                    self.flag_module_found = True
+
+            if not self.flag_module_found:
+                raise RuntimeError(f"{feat_layers} not found in model!")
+
+            for name, module in self.ref_model.named_modules():
+                if name in feat_layers:
+                    module.register_forward_hook(self.hook_ref)
+
+    def hook(self, module, input, output):
+        # keep the computation graph: the MIRO penalty must backpropagate
+        # into the guest model, so the features are not detached here
+        self._features.append(output)
+
+    def hook_ref(self, module, input, output):
+        # the reference model is frozen, so its features are detached
+        self._features_ref.append(output.detach())
+
+    def extract_intermediate_features(self, tensor_x, tensor_y, tensor_d, others=None):
+        """
+        extract features for each hooked layer of the neural network
+        """
+        self.clear_features()
+        self.guest_model(tensor_x)
+        return self._features
+
+    def clear_features(self):
+        self._features.clear()
+
+    def cal_feat_layers_ref_model(self, tensor_x, tensor_y, tensor_d, others=None):
+        self._features_ref.clear()
+        self.ref_model(tensor_x)
+        return self._features_ref
diff --git a/domainlab/algos/trainers/train_miro_utils.py b/domainlab/algos/trainers/train_miro_utils.py
new file mode 100644
index 000000000..c876ca82f
--- /dev/null
+++ b/domainlab/algos/trainers/train_miro_utils.py
@@ -0,0 +1,34 @@
+"""
+Laplace approximation for mutual information estimation:
+mean and variance encoders for the Gaussian variational decoder
+"""
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+
+class MeanEncoder(nn.Module):
+    """Identity function"""
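+    # MIRO fixes the mean map to the identity: the bound is tightened only
+    # through the learned variance (VarianceEncoder below), and the mean term
+    # reduces to direct feature matching against the frozen reference model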
+    def __init__(self, inter_layer_feat_shape):
+        super().__init__()
+        self.inter_layer_feat_shape = inter_layer_feat_shape
+
+    def forward(self, x):
+        return x
+
+
+class VarianceEncoder(nn.Module):
+    """Bias-only model with diagonal covariance"""
+    def __init__(self, inter_layer_feat_shape, init=0.1, eps=1e-5):
+        super().__init__()
+        self.inter_layer_feat_shape = inter_layer_feat_shape
+        self.eps = eps
+
+        # inverse softplus, so that softplus(b) + eps == init at initialization
+        init = (torch.as_tensor(init - eps).exp() - 1.0).log()
+        b_shape = inter_layer_feat_shape
+        self.b = nn.Parameter(torch.full(b_shape, init))
+
+    def forward(self, feat_layer_tensor_batch):
+        """
+        batch (population) level variance: a learned constant per feature
+        dimension, independent of the input; softplus keeps it positive
+        """
+        return F.softplus(self.b) + self.eps
diff --git a/domainlab/algos/trainers/zoo_trainer.py b/domainlab/algos/trainers/zoo_trainer.py
index 2d5529738..c11049375 100644
--- a/domainlab/algos/trainers/zoo_trainer.py
+++ b/domainlab/algos/trainers/zoo_trainer.py
@@ -12,6 +12,7 @@
 from domainlab.algos.trainers.train_irm import TrainerIRM
 from domainlab.algos.trainers.train_causIRL import TrainerCausalIRL
 from domainlab.algos.trainers.train_coral import TrainerCoral
+from domainlab.algos.trainers.train_miro import TrainerMiro
 
 
 class TrainerChainNodeGetter(object):
@@ -59,6 +60,7 @@ def __call__(self, lst_candidates=None, default=None, lst_excludes=None):
         chain = TrainerHyperScheduler(chain)
         chain = TrainerCausalIRL(chain)
         chain = TrainerCoral(chain)
+        chain = TrainerMiro(chain)
         node = chain.handle(self.request)
         head = node
         while self._list_str_trainer:
diff --git a/domainlab/arg_parser.py b/domainlab/arg_parser.py
index cf9028cda..d456e3045 100644
--- a/domainlab/arg_parser.py
+++ b/domainlab/arg_parser.py
@@ -8,6 +8,7 @@
 from domainlab.algos.trainers.args_dial import add_args2parser_dial
 from domainlab.algos.trainers.compos.matchdg_args import add_args2parser_matchdg
+from domainlab.algos.trainers.args_miro import add_args2parser_miro
 from domainlab.models.args_jigen import add_args2parser_jigen
 from domainlab.models.args_vae import add_args2parser_vae
 from domainlab.utils.logger import Logger
@@ -356,6 +357,8 @@ def mk_parser_main():
     arg_group_vae = add_args2parser_vae(arg_group_vae)
     arg_group_matchdg = parser.add_argument_group("matchdg")
     arg_group_matchdg = add_args2parser_matchdg(arg_group_matchdg)
+    arg_group_miro = parser.add_argument_group("miro")
+    arg_group_miro = add_args2parser_miro(arg_group_miro)
     arg_group_jigen = parser.add_argument_group("jigen")
     arg_group_jigen = add_args2parser_jigen(arg_group_jigen)
     args_group_dial = parser.add_argument_group("dial")
diff --git a/domainlab/models/a_model.py b/domainlab/models/a_model.py
index 71b0db334..113b8c456 100644
--- a/domainlab/models/a_model.py
+++ b/domainlab/models/a_model.py
@@ -103,14 +103,15 @@ def _extend_loss(self, tensor_x, tensor_y, tensor_d, others=None):
             return self._decoratee.cal_reg_loss(tensor_x, tensor_y, tensor_d, others)
         return None, None
 
-    def forward(self, tensor_x, tensor_y, tensor_d, others=None):
+    def forward(self, tensor_x):
         """forward.
-
-        :param x:
-        :param y:
-        :param d:
+        :param tensor_x: input tensor
+        :return: semantic features extracted from tensor_x
         """
-        return self.cal_loss(tensor_x, tensor_y, tensor_d, others)
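+        # forward now performs feature extraction only: a plain
+        # model(tensor_x) call triggers the forward hooks registered by
+        # TrainerMiroModelWraper; losses are computed explicitly via
+        # cal_loss (see the updated tests/test_model_diva.py below)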
""" params = vars(self) diff --git a/tests/test_miro.py b/tests/test_miro.py new file mode 100644 index 000000000..ca4b3f51c --- /dev/null +++ b/tests/test_miro.py @@ -0,0 +1,13 @@ +""" +end-end test for mutual information regulation +""" +from tests.utils_test import utils_test_algo + + +def test_miro(): + """ + train with MIRO + """ + args = "--te_d=caltech --task=mini_vlcs --debug --bs=2 --model=erm \ + --trainer=miro --nname=alexnet" + utils_test_algo(args) diff --git a/tests/test_miro2.py b/tests/test_miro2.py new file mode 100644 index 000000000..4fba341d7 --- /dev/null +++ b/tests/test_miro2.py @@ -0,0 +1,14 @@ +""" +end-end test for mutual information regulation +""" +from tests.utils_test import utils_test_algo + + +def test_miro2(): + """ + train with MIRO + """ + args = "--te_d=2 --tr_d 0 1 --task=mnistcolor10 --debug --bs=100 --model=erm \ + --trainer=miro --nname=conv_bn_pool_2 \ + --layers2extract_feats _net_invar_feat.conv_net.5" + utils_test_algo(args) diff --git a/tests/test_miro3.py b/tests/test_miro3.py new file mode 100644 index 000000000..75341b403 --- /dev/null +++ b/tests/test_miro3.py @@ -0,0 +1,17 @@ +""" +end-end test for mutual information regulation +""" +import pytest +from tests.utils_test import utils_test_algo + + +def test_miro3(): + """ + train with MIRO + """ + args = "--te_d=caltech --task=mini_vlcs --debug --bs=2 --model=erm \ + --trainer=miro --nname=alexnet \ + --layers2extract_feats features" + with pytest.raises(RuntimeError): + utils_test_algo(args) + raise RuntimeError("This is a runtime error") diff --git a/tests/test_model_diva.py b/tests/test_model_diva.py index 15ad609c7..3ea4b0e86 100644 --- a/tests/test_model_diva.py +++ b/tests/test_model_diva.py @@ -32,4 +32,4 @@ def test_model_diva(): ) imgs, y_s, d_s = mk_rand_xyd(28, y_dim, 2, 2) _, _, _, _, _ = model.infer_y_vpicn(imgs) - model(imgs, y_s, d_s) + model.cal_loss(imgs, y_s, d_s)