-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #878 from marrlab/miro
- Loading branch information
Showing
13 changed files
with
313 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
# MIRO: Mutual-Information Regularization | ||
## Mutual Information Regularization with Oracle (MIRO). | ||
|
||
### Pre-requisite: Variational lower bound on mutual information | ||
|
||
Barber, David, and Felix Agakov. "The im algorithm: a variational approach to information maximization." Advances in neural information processing systems 16, no. 320 (2004): 201. | ||
|
||
$$I(X,Y)=H(Y)-H(Y|X)=-{\langle\log p_y(Y)\rangle}_{p_y(Y)}+{\langle\log p(Y|X)\rangle}_{p(X,Y)}$$ | ||
|
||
Given variational distribution of $q(x|y)$ as decoder (i.e. $Y$ encodes information from $X$) | ||
|
||
|
||
Since | ||
|
||
$$KL\left(p(X|Y)|q(X|Y)\right)={\langle\log p(X|Y)\rangle}_{p(X|Y)}-\langle{\log q(X,Y)\rangle}_{p(X|Y)} >0$$ | ||
|
||
We have | ||
|
||
$${\langle\log p(X|Y)\rangle}_{p(X|Y)}>{\langle\log q(X,Y)\rangle}_{p(X|Y)}$$ | ||
|
||
Then | ||
|
||
$$I(X,Y)=-{\langle\log p_y(Y)\rangle}_{p_y(Y)}+{\langle\log p(Y|X)\rangle}_{p(X,Y)}>-{\langle\log p_y(Y)\rangle}_{p_y(Y)}+{\langle\log q(X,Y)\rangle}_{p(X|Y)}$$ | ||
|
||
with the lower bound being | ||
|
||
$$-{\langle\log p_y(Y)\rangle}_{p_y(Y)}+{\langle\log q(X,Y)\rangle}_{p(X|Y)}$$ | ||
|
||
To optimize the lower bound, one can iterate | ||
|
||
- fix decoder $q(X|Y)$ and optimize encoder $Y=g(X;\theta) + \epsilon$ | ||
- fix encoder parameter $\theta$, tune decoder to alleviate the lower bound | ||
|
||
#### Laplace approximation | ||
|
||
decoding posterior: | ||
|
||
$$p(X|Y) \sim Gaussian(Y|[\Sigma^{-1}]_{ij}=\frac{\partial^2 \log p(x|y)}{\partial x_i\partial x_j})$$ | ||
|
||
when $|Y|$ is large (large deviation from zero contains more information, which must be explained by non-typical $X$) | ||
|
||
#### Linear Gaussian | ||
|
||
|
||
The bound $H(X)+{\langle\log q(X|Y)\rangle}_{p(x,y)}$ becomes | ||
|
||
$$\sum_i {\langle|X_i-m(Y_i)|_{|\Sigma^{-1}|(Y_i)} + \log det(\Sigma(Y_i))\rangle}_{p(Y|X)}$$ | ||
|
||
|
||
## MIRO | ||
|
||
MIRO try to match the pre-trained model's features layer by layer to the target neural network we want to train for domain invariance in terms of mutual information. They use a constant identity encoder on feature from target neural network, then a population variance $\Sigma$ (forced to be diagonal). | ||
|
||
Let $z$ denote the intermediate features of each layer, let $f_0$ be the pre-trained model, $f$ be the target neural network. Let $x$ be the input data. | ||
|
||
$$z_f=f(x)$$ | ||
|
||
$$z_{f_0}=f^{(0)}(x)$$ | ||
|
||
the lower bound for Mutual information for instance $i$ is | ||
|
||
|
||
$$\log|\Sigma| + ||z^{(i)}_{f_0}-id(z^{i})||_{\Sigma}^{-1}$$ | ||
|
||
where $id$ is the mean map | ||
|
||
For diagonal $\Sigma$, determinant is simply multiplication of all diagonal values, | ||
|
||
$$\log|\Sigma|=\sum_{k} \log \sigma_k + ||{z_k}^{(i)}_{f_0}-z_k^{i}||{\sigma_k}^{-1}$$ | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
""" | ||
miro trainer configurations | ||
""" | ||
|
||
|
||
def add_args2parser_miro(parser): | ||
""" | ||
append hyper-parameters to the main argparser | ||
""" | ||
arg_group_miro = parser.add_argument_group("miro") | ||
arg_group_miro.add_argument( | ||
"--layers2extract_feats", | ||
nargs="*", | ||
default=None, | ||
help="layer names separated by space to extract features", | ||
) | ||
return parser |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
""" | ||
author: Kakao Brain. | ||
# https://arxiv.org/pdf/2203.10789#page=3.77 | ||
# [aut] xudong, alexej | ||
""" | ||
|
||
import torch | ||
from torch import nn | ||
from domainlab.algos.trainers.train_basic import TrainerBasic | ||
from domainlab.algos.trainers.train_miro_utils import \ | ||
MeanEncoder, VarianceEncoder | ||
from domainlab.algos.trainers.train_miro_model_wraper import \ | ||
TrainerMiroModelWraper | ||
|
||
|
||
class TrainerMiro(TrainerBasic): | ||
"""Mutual-Information Regularization with Oracle""" | ||
def before_tr(self): | ||
self.model_wraper = TrainerMiroModelWraper() | ||
self.model_wraper.accept(self.model, | ||
name_feat_layers2extract=self.aconf.layers2extract_feats) | ||
self.mean_encoders = None | ||
self.var_encoders = None | ||
super().before_tr() | ||
shapes = self.model_wraper.get_shapes(self.input_tensor_shape) | ||
self.mean_encoders = nn.ModuleList([ | ||
MeanEncoder(shape) for shape in shapes | ||
]) | ||
self.var_encoders = nn.ModuleList([ | ||
VarianceEncoder(shape) for shape in shapes | ||
]) | ||
|
||
def _cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others=None): | ||
# list_batch_inter_feat_new are features for each layer | ||
list_batch_inter_feat_new = \ | ||
self.model_wraper.extract_intermediate_features( | ||
tensor_x, tensor_y, tensor_d, others) | ||
|
||
# reference model | ||
with torch.no_grad(): | ||
list_batch_inter_feat_ref = self.model_wraper.cal_feat_layers_ref_model( | ||
tensor_x, tensor_y, tensor_d, others) | ||
# dim(list_batch_inter_feat_ref)=[size_batch, dim_feat] | ||
if self.mean_encoders is None: | ||
device = tensor_x.device | ||
return [torch.zeros(tensor_x.shape[0]).to(device)], [self.aconf.gamma_reg] | ||
|
||
reg_loss = 0. | ||
num_layers = len(self.mean_encoders) | ||
device = tensor_x.device | ||
for ind_layer in range(num_layers): | ||
# layerwise mutual information regularization | ||
mean_encoder = self.mean_encoders[ind_layer].to(device) | ||
feat = list_batch_inter_feat_new[ind_layer] | ||
feat = feat.to(device) | ||
mean = mean_encoder(feat) | ||
var_encoder = self.var_encoders[ind_layer].to(device) | ||
var = var_encoder(feat) | ||
mean_ref = list_batch_inter_feat_ref[ind_layer] | ||
mean_ref = mean_ref.to(device) | ||
vlb = (mean - mean_ref).pow(2).div(var) + var.log() | ||
reg_loss += vlb.mean(dim=tuple(range(1, vlb.dim()))) / 2. | ||
return [reg_loss], [self.aconf.gamma_reg] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
""" | ||
https://arxiv.org/pdf/2203.10789#page=3.77 | ||
""" | ||
import copy | ||
import torch | ||
from torch import nn | ||
|
||
|
||
class TrainerMiroModelWraper(): | ||
"""Mutual-Information Regularization with Oracle""" | ||
def __init__(self): | ||
self._features = [] | ||
self._features_ref = [] | ||
self.guest_model = None | ||
self.ref_model = None | ||
self.flag_module_found = False | ||
|
||
def get_shapes(self, input_shape): | ||
# get shape of intermediate features | ||
self.clear_features() | ||
with torch.no_grad(): | ||
dummy = torch.rand(*input_shape).to(next(self.guest_model.parameters()).device) | ||
self.guest_model(dummy) | ||
shapes = [feat.shape for feat in self._features] | ||
return shapes | ||
|
||
def accept(self, guest_model, name_feat_layers2extract=None): | ||
self.guest_model = guest_model | ||
self.ref_model = copy.deepcopy(guest_model) | ||
self.register_feature_storage_hook(name_feat_layers2extract) | ||
|
||
def register_feature_storage_hook(self, feat_layers=None): | ||
# memorize features for each layer in self._feautres list | ||
if feat_layers is None: | ||
module = list(self.guest_model.children())[-1] | ||
module.register_forward_hook(self.hook) | ||
module_ref = list(self.ref_model.children())[-1] | ||
module_ref.register_forward_hook(self.hook_ref) | ||
else: | ||
for name, module in self.guest_model.named_modules(): | ||
if name in feat_layers: | ||
module.register_forward_hook(self.hook) | ||
self.flag_module_found = True | ||
|
||
if not self.flag_module_found: | ||
raise RuntimeError(f"{feat_layers} not found in model!") | ||
|
||
for name, module in self.ref_model.named_modules(): | ||
if name in feat_layers: | ||
module.register_forward_hook(self.hook_ref) | ||
|
||
def hook(self, module, input, output): | ||
self._features.append(output.detach()) | ||
|
||
def hook_ref(self, module, input, output): | ||
self._features_ref.append(output.detach()) | ||
|
||
def extract_intermediate_features(self, tensor_x, tensor_y, tensor_d, others=None): | ||
""" | ||
extract features for each layer of the neural network | ||
""" | ||
self.clear_features() | ||
self.guest_model(tensor_x) | ||
return self._features | ||
|
||
def clear_features(self): | ||
self._features.clear() | ||
|
||
def cal_feat_layers_ref_model(self, tensor_x, tensor_y, tensor_d, others=None): | ||
self._features_ref.clear() | ||
self.ref_model(tensor_x) | ||
return self._features_ref |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
""" | ||
Laplace approximation for Mutual Information estimation | ||
""" | ||
import torch | ||
import torch.nn.functional as F | ||
from torch import nn | ||
|
||
|
||
class MeanEncoder(nn.Module): | ||
"""Identity function""" | ||
def __init__(self, inter_layer_feat_shape): | ||
super().__init__() | ||
self.inter_layer_feat_shape = inter_layer_feat_shape | ||
|
||
def forward(self, x): | ||
return x | ||
|
||
|
||
class VarianceEncoder(nn.Module): | ||
"""Bias-only model with diagonal covariance""" | ||
def __init__(self, inter_layer_feat_shape, init=0.1, eps=1e-5): | ||
super().__init__() | ||
self.inter_layer_feat_shape = inter_layer_feat_shape | ||
self.eps = eps | ||
|
||
init = (torch.as_tensor(init - eps).exp() - 1.0).log() | ||
b_shape = inter_layer_feat_shape | ||
self.b = nn.Parameter(torch.full(b_shape, init)) | ||
|
||
def forward(self, feat_layer_tensor_batch): | ||
""" | ||
train batch(population) level variance | ||
""" | ||
return F.softplus(self.b) + self.eps |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
""" | ||
end-end test for mutual information regulation | ||
""" | ||
from tests.utils_test import utils_test_algo | ||
|
||
|
||
def test_miro(): | ||
""" | ||
train with MIRO | ||
""" | ||
args = "--te_d=caltech --task=mini_vlcs --debug --bs=2 --model=erm \ | ||
--trainer=miro --nname=alexnet" | ||
utils_test_algo(args) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
""" | ||
end-end test for mutual information regulation | ||
""" | ||
from tests.utils_test import utils_test_algo | ||
|
||
|
||
def test_miro2(): | ||
""" | ||
train with MIRO | ||
""" | ||
args = "--te_d=2 --tr_d 0 1 --task=mnistcolor10 --debug --bs=100 --model=erm \ | ||
--trainer=miro --nname=conv_bn_pool_2 \ | ||
--layers2extract_feats _net_invar_feat.conv_net.5" | ||
utils_test_algo(args) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
""" | ||
end-end test for mutual information regulation | ||
""" | ||
import pytest | ||
from tests.utils_test import utils_test_algo | ||
|
||
|
||
def test_miro3(): | ||
""" | ||
train with MIRO | ||
""" | ||
args = "--te_d=caltech --task=mini_vlcs --debug --bs=2 --model=erm \ | ||
--trainer=miro --nname=alexnet \ | ||
--layers2extract_feats features" | ||
with pytest.raises(RuntimeError): | ||
utils_test_algo(args) | ||
raise RuntimeError("This is a runtime error") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters