Merge pull request #878 from marrlab/miro
smilesun authored Oct 15, 2024
2 parents f4b9c1e + 7ab5e0a commit 8dafe68
Showing 13 changed files with 313 additions and 4 deletions.
70 changes: 70 additions & 0 deletions docs/doc_miro.md
@@ -0,0 +1,70 @@
# MIRO: Mutual-Information Regularization
## Mutual Information Regularization with Oracle (MIRO).

### Pre-requisite: Variational lower bound on mutual information

Barber, David, and Felix Agakov. "The IM algorithm: a variational approach to information maximization." Advances in Neural Information Processing Systems 16 (2004): 201.

$$I(X,Y)=H(X)-H(X|Y)=-{\langle\log p(X)\rangle}_{p(X)}+{\langle\log p(X|Y)\rangle}_{p(X,Y)}$$

Given a variational distribution $q(X|Y)$ as decoder (i.e. $Y$ encodes information from $X$, so the decoder reconstructs $X$ from $Y$).


Since

$$KL\left(p(X|Y)\,\|\,q(X|Y)\right)={\langle\log p(X|Y)\rangle}_{p(X|Y)}-{\langle\log q(X|Y)\rangle}_{p(X|Y)} \geq 0$$

We have

$${\langle\log p(X|Y)\rangle}_{p(X|Y)}\geq{\langle\log q(X|Y)\rangle}_{p(X|Y)}$$

Then

$$I(X,Y)=-{\langle\log p(X)\rangle}_{p(X)}+{\langle\log p(X|Y)\rangle}_{p(X,Y)}\geq-{\langle\log p(X)\rangle}_{p(X)}+{\langle\log q(X|Y)\rangle}_{p(X,Y)}$$

with the lower bound being

$$-{\langle\log p(X)\rangle}_{p(X)}+{\langle\log q(X|Y)\rangle}_{p(X,Y)}=H(X)+{\langle\log q(X|Y)\rangle}_{p(X,Y)}$$

To optimize the lower bound, one can iterate between two steps (a sketch follows the list):

- fix the decoder $q(X|Y)$ and optimize the encoder $Y=g(X;\theta) + \epsilon$
- fix the encoder parameter $\theta$ and tune the decoder to tighten the lower bound
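
A minimal PyTorch sketch of this alternation (illustrative only, not from the referenced paper or this PR; a linear encoder with additive noise and a Gaussian decoder with diagonal variance are assumed, and all sizes and learning rates are arbitrary):

```python
import torch
from torch import nn

torch.manual_seed(0)
dim_x, dim_y = 8, 4

encoder = nn.Linear(dim_x, dim_y)            # deterministic part g(x; theta)
decoder_mean = nn.Linear(dim_y, dim_x)       # m(y), mean of q(x|y)
log_var = nn.Parameter(torch.zeros(dim_x))   # diagonal log-variance of q(x|y)

opt_enc = torch.optim.SGD(encoder.parameters(), lr=1e-2)
opt_dec = torch.optim.SGD(list(decoder_mean.parameters()) + [log_var], lr=1e-2)


def neg_bound(x):
    """negative of <log q(X|Y)>_{p(X,Y)}; H(X) is constant w.r.t. the parameters"""
    y = encoder(x) + 0.1 * torch.randn(x.shape[0], dim_y)  # y = g(x; theta) + eps
    diff2 = (x - decoder_mean(y)).pow(2)
    return 0.5 * (diff2 / log_var.exp() + log_var).sum(dim=1).mean()


x_batch = torch.randn(32, dim_x)
for _ in range(10):
    # step 1: decoder fixed, update the encoder parameters theta
    opt_enc.zero_grad()
    neg_bound(x_batch).backward()
    opt_enc.step()
    # step 2: encoder fixed, update the decoder to tighten the bound
    opt_dec.zero_grad()
    neg_bound(x_batch).backward()
    opt_dec.step()
```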

#### Laplace approximation

The decoding posterior is approximated by a Gaussian centered at its mode $x^*(y)$, with precision given by the negative Hessian of the log-posterior:

$$p(X|Y) \approx \mathcal{N}\left(X \mid x^*(Y),\, \Sigma(Y)\right), \qquad [\Sigma^{-1}]_{ij}=-\frac{\partial^2 \log p(x|y)}{\partial x_i\partial x_j}\bigg|_{x=x^*(y)}$$

This is applied when $|Y|$ is large: a large deviation from zero carries more information, which must be explained by a non-typical $X$.
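
As a one-dimensional illustration (added here, not from the original reference), a second-order expansion of $\log p(x|y)$ around the mode $x^*(y)$ gives

$$\log p(x|y) \approx \log p(x^*(y)\,|\,y) - \frac{(x-x^*(y))^2}{2\sigma^2(y)}, \qquad \frac{1}{\sigma^2(y)} = -\frac{\partial^2 \log p(x|y)}{\partial x^2}\bigg|_{x=x^*(y)},$$

which is exactly a Gaussian log-density in $x$ up to a constant.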

#### Linear Gaussian


For a Gaussian decoder $q(X|Y)=\mathcal{N}\left(X \mid m(Y),\, \Sigma(Y)\right)$, the bound $H(X)+{\langle\log q(X|Y)\rangle}_{p(X,Y)}$ becomes, up to an additive constant,

$$H(X)-\frac{1}{2}\sum_i {\left\langle \|X_i-m(Y_i)\|^2_{\Sigma^{-1}(Y_i)} + \log\det\Sigma(Y_i)\right\rangle}_{p(Y_i|X_i)}$$

where the sum runs over the data instances $X_i$.
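
This follows (a step added here for completeness) from the Gaussian log-density of the decoder,

$$\log q(x|y) = -\frac{1}{2}\left(\|x-m(y)\|^2_{\Sigma^{-1}(y)} + \log\det\Sigma(y) + d\log 2\pi\right),$$

where $d$ is the dimension of $x$; the $d\log 2\pi$ term is constant and can be dropped.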


## MIRO

MIRO tries to match, in terms of mutual information, the pre-trained (oracle) model's features layer by layer to those of the target neural network we want to train for domain invariance. It uses a constant identity encoder on the features of the target neural network, together with a population variance $\Sigma$ (forced to be diagonal).

Let $z$ denote the intermediate features of each layer, let $f_0$ be the pre-trained model, $f$ be the target neural network. Let $x$ be the input data.

$$z_f=f(x)$$

$$z_{f_0}=f_0(x)$$

Up to constants, the per-instance regularization loss (the negative of the variational lower bound on the mutual information) for instance $i$ is

$$\log|\Sigma| + \|z^{(i)}_{f_0}-id(z^{(i)}_{f})\|^2_{\Sigma^{-1}}$$

where the identity map $id$ serves as the mean encoder.

For diagonal $\Sigma$, the determinant is simply the product of the diagonal entries, so $\log|\Sigma|=\sum_k \log\sigma_k$ and the per-instance loss becomes

$$\sum_{k} \left(\log \sigma_k + \frac{\left(z^{(i)}_{f_0,k}-z^{(i)}_{f,k}\right)^2}{\sigma_k}\right)$$
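
For intuition, a minimal numeric sketch (not part of the PR code) of this per-instance loss with a diagonal $\Sigma$ stored as a vector of variances $\sigma_k$:

```python
import torch

z_f = torch.randn(16)         # target-network features z_f^{(i)}
z_f0 = torch.randn(16)        # pre-trained (oracle) features z_{f_0}^{(i)}
sigma = torch.rand(16) + 0.1  # diagonal entries of Sigma (variances)

# sum_k [ log(sigma_k) + (z_{f_0,k} - z_{f,k})^2 / sigma_k ]
loss = (sigma.log() + (z_f0 - z_f).pow(2) / sigma).sum()
print(float(loss))
```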

3 changes: 3 additions & 0 deletions domainlab/algos/trainers/a_trainer.py
@@ -91,6 +91,8 @@ def __init__(self, successor_node=None, extend=None):
self._ma_iter = 0
#
self.list_reg_over_task_ratio = None
# MIRO
self.input_tensor_shape = None

@property
def model(self):
@@ -203,6 +205,7 @@ def cal_reg_loss_over_task_loss_ratio(self):
for ind_batch, (tensor_x, tensor_y, tensor_d, *others) in enumerate(
self.loader_tr
):
self.input_tensor_shape = tensor_x.shape
if ind_batch >= self.aconf.nb4reg_over_task_ratio:
return
tensor_x, tensor_y, tensor_d = (
17 changes: 17 additions & 0 deletions domainlab/algos/trainers/args_miro.py
@@ -0,0 +1,17 @@
"""
miro trainer configurations
"""


def add_args2parser_miro(parser):
"""
append hyper-parameters to the main argparser
"""
arg_group_miro = parser.add_argument_group("miro")
arg_group_miro.add_argument(
"--layers2extract_feats",
nargs="*",
default=None,
help="layer names separated by space to extract features",
)
return parser
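
A small parsing sketch (illustrative; the layer names below are hypothetical and depend on the chosen network) showing how the `nargs="*"` flag behaves:

```python
import argparse

from domainlab.algos.trainers.args_miro import add_args2parser_miro

parser = argparse.ArgumentParser()
parser = add_args2parser_miro(parser)
# hypothetical layer names, separated by spaces on the command line
args = parser.parse_args(["--layers2extract_feats", "features.3", "features.8"])
print(args.layers2extract_feats)  # ['features.3', 'features.8']
```
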
63 changes: 63 additions & 0 deletions domainlab/algos/trainers/train_miro.py
@@ -0,0 +1,63 @@
"""
author: Kakao Brain.
# https://arxiv.org/pdf/2203.10789#page=3.77
# [aut] xudong, alexej
"""

import torch
from torch import nn
from domainlab.algos.trainers.train_basic import TrainerBasic
from domainlab.algos.trainers.train_miro_utils import \
MeanEncoder, VarianceEncoder
from domainlab.algos.trainers.train_miro_model_wraper import \
TrainerMiroModelWraper


class TrainerMiro(TrainerBasic):
"""Mutual-Information Regularization with Oracle"""
def before_tr(self):
self.model_wraper = TrainerMiroModelWraper()
self.model_wraper.accept(self.model,
name_feat_layers2extract=self.aconf.layers2extract_feats)
self.mean_encoders = None
self.var_encoders = None
super().before_tr()
shapes = self.model_wraper.get_shapes(self.input_tensor_shape)
self.mean_encoders = nn.ModuleList([
MeanEncoder(shape) for shape in shapes
])
self.var_encoders = nn.ModuleList([
VarianceEncoder(shape) for shape in shapes
])

def _cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others=None):
# list_batch_inter_feat_new are features for each layer
list_batch_inter_feat_new = \
self.model_wraper.extract_intermediate_features(
tensor_x, tensor_y, tensor_d, others)

# reference model
with torch.no_grad():
list_batch_inter_feat_ref = self.model_wraper.cal_feat_layers_ref_model(
tensor_x, tensor_y, tensor_d, others)
# dim(list_batch_inter_feat_ref)=[size_batch, dim_feat]
if self.mean_encoders is None:
device = tensor_x.device
return [torch.zeros(tensor_x.shape[0]).to(device)], [self.aconf.gamma_reg]

reg_loss = 0.
num_layers = len(self.mean_encoders)
device = tensor_x.device
for ind_layer in range(num_layers):
# layerwise mutual information regularization
mean_encoder = self.mean_encoders[ind_layer].to(device)
feat = list_batch_inter_feat_new[ind_layer]
feat = feat.to(device)
mean = mean_encoder(feat)
var_encoder = self.var_encoders[ind_layer].to(device)
var = var_encoder(feat)
mean_ref = list_batch_inter_feat_ref[ind_layer]
mean_ref = mean_ref.to(device)
vlb = (mean - mean_ref).pow(2).div(var) + var.log()
reg_loss += vlb.mean(dim=tuple(range(1, vlb.dim()))) / 2.
return [reg_loss], [self.aconf.gamma_reg]
72 changes: 72 additions & 0 deletions domainlab/algos/trainers/train_miro_model_wraper.py
@@ -0,0 +1,72 @@
"""
https://arxiv.org/pdf/2203.10789#page=3.77
"""
import copy
import torch
from torch import nn


class TrainerMiroModelWraper():
"""Mutual-Information Regularization with Oracle"""
def __init__(self):
self._features = []
self._features_ref = []
self.guest_model = None
self.ref_model = None
self.flag_module_found = False

def get_shapes(self, input_shape):
# get shape of intermediate features
self.clear_features()
with torch.no_grad():
dummy = torch.rand(*input_shape).to(next(self.guest_model.parameters()).device)
self.guest_model(dummy)
shapes = [feat.shape for feat in self._features]
return shapes

def accept(self, guest_model, name_feat_layers2extract=None):
self.guest_model = guest_model
self.ref_model = copy.deepcopy(guest_model)
self.register_feature_storage_hook(name_feat_layers2extract)

def register_feature_storage_hook(self, feat_layers=None):
# memorize features of each hooked layer in the self._features list
if feat_layers is None:
module = list(self.guest_model.children())[-1]
module.register_forward_hook(self.hook)
module_ref = list(self.ref_model.children())[-1]
module_ref.register_forward_hook(self.hook_ref)
else:
for name, module in self.guest_model.named_modules():
if name in feat_layers:
module.register_forward_hook(self.hook)
self.flag_module_found = True

if not self.flag_module_found:
raise RuntimeError(f"{feat_layers} not found in model!")

for name, module in self.ref_model.named_modules():
if name in feat_layers:
module.register_forward_hook(self.hook_ref)

def hook(self, module, input, output):
self._features.append(output.detach())

def hook_ref(self, module, input, output):
self._features_ref.append(output.detach())

def extract_intermediate_features(self, tensor_x, tensor_y, tensor_d, others=None):
"""
extract features for each layer of the neural network
"""
self.clear_features()
self.guest_model(tensor_x)
return self._features

def clear_features(self):
self._features.clear()

def cal_feat_layers_ref_model(self, tensor_x, tensor_y, tensor_d, others=None):
self._features_ref.clear()
self.ref_model(tensor_x)
return self._features_ref
34 changes: 34 additions & 0 deletions domainlab/algos/trainers/train_miro_utils.py
@@ -0,0 +1,34 @@
"""
Laplace approximation for Mutual Information estimation
"""
import torch
import torch.nn.functional as F
from torch import nn


class MeanEncoder(nn.Module):
"""Identity function"""
def __init__(self, inter_layer_feat_shape):
super().__init__()
self.inter_layer_feat_shape = inter_layer_feat_shape

def forward(self, x):
return x


class VarianceEncoder(nn.Module):
"""Bias-only model with diagonal covariance"""
def __init__(self, inter_layer_feat_shape, init=0.1, eps=1e-5):
super().__init__()
self.inter_layer_feat_shape = inter_layer_feat_shape
self.eps = eps

init = (torch.as_tensor(init - eps).exp() - 1.0).log()
b_shape = inter_layer_feat_shape
self.b = nn.Parameter(torch.full(b_shape, init))

def forward(self, feat_layer_tensor_batch):
"""
train a batch (population) level variance, independent of the individual input
"""
return F.softplus(self.b) + self.eps
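
A minimal usage sketch of the two encoders above (illustrative only; the feature shape `(1, 64)` is an assumed placeholder for one intermediate layer's shape):

```python
import torch

from domainlab.algos.trainers.train_miro_utils import MeanEncoder, VarianceEncoder

shape = (1, 64)                      # assumed shape of one intermediate feature
mean_encoder = MeanEncoder(shape)    # identity mean map
var_encoder = VarianceEncoder(shape)

feat = torch.randn(8, 64)            # a batch of target-network features
mean = mean_encoder(feat)            # identical to feat
var = var_encoder(feat)              # softplus(b) + eps, broadcast over the batch
print(mean.shape, var.shape)         # torch.Size([8, 64]) torch.Size([1, 64])
```
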
2 changes: 2 additions & 0 deletions domainlab/algos/trainers/zoo_trainer.py
@@ -12,6 +12,7 @@
from domainlab.algos.trainers.train_irm import TrainerIRM
from domainlab.algos.trainers.train_causIRL import TrainerCausalIRL
from domainlab.algos.trainers.train_coral import TrainerCoral
from domainlab.algos.trainers.train_miro import TrainerMiro


class TrainerChainNodeGetter(object):
@@ -59,6 +60,7 @@ def __call__(self, lst_candidates=None, default=None, lst_excludes=None):
chain = TrainerHyperScheduler(chain)
chain = TrainerCausalIRL(chain)
chain = TrainerCoral(chain)
chain = TrainerMiro(chain)
node = chain.handle(self.request)
head = node
while self._list_str_trainer:
3 changes: 3 additions & 0 deletions domainlab/arg_parser.py
@@ -8,6 +8,7 @@

from domainlab.algos.trainers.args_dial import add_args2parser_dial
from domainlab.algos.trainers.compos.matchdg_args import add_args2parser_matchdg
from domainlab.algos.trainers.args_miro import add_args2parser_miro
from domainlab.models.args_jigen import add_args2parser_jigen
from domainlab.models.args_vae import add_args2parser_vae
from domainlab.utils.logger import Logger
@@ -356,6 +357,8 @@ def mk_parser_main():
arg_group_vae = add_args2parser_vae(arg_group_vae)
arg_group_matchdg = parser.add_argument_group("matchdg")
arg_group_matchdg = add_args2parser_matchdg(arg_group_matchdg)
arg_group_miro = parser.add_argument_group("miro")
arg_group_miro = add_args2parser_miro(arg_group_miro)
arg_group_jigen = parser.add_argument_group("jigen")
arg_group_jigen = add_args2parser_jigen(arg_group_jigen)
args_group_dial = parser.add_argument_group("dial")
7 changes: 4 additions & 3 deletions domainlab/models/a_model.py
@@ -103,14 +103,15 @@ def _extend_loss(self, tensor_x, tensor_y, tensor_d, others=None):
return self._decoratee.cal_reg_loss(tensor_x, tensor_y, tensor_d, others)
return None, None

def forward(self, tensor_x, tensor_y, tensor_d, others=None):
def forward(self, tensor_x):
"""forward.
:param tensor_x: input tensor
"""
return self.cal_loss(tensor_x, tensor_y, tensor_d, others)
out = self.extract_semantic_feat(tensor_x)
return out

def extract_semantic_feat(self, tensor_x):
"""
@@ -205,7 +206,7 @@ def name(self):

def print_parameters(self):
"""
Function to print all parameters of the object.
Can be used to print the parameters of every child class.
"""
params = vars(self)
13 changes: 13 additions & 0 deletions tests/test_miro.py
@@ -0,0 +1,13 @@
"""
end-to-end test for mutual information regularization (MIRO)
"""
from tests.utils_test import utils_test_algo


def test_miro():
"""
train with MIRO
"""
args = "--te_d=caltech --task=mini_vlcs --debug --bs=2 --model=erm \
--trainer=miro --nname=alexnet"
utils_test_algo(args)
14 changes: 14 additions & 0 deletions tests/test_miro2.py
@@ -0,0 +1,14 @@
"""
end-to-end test for mutual information regularization (MIRO)
"""
from tests.utils_test import utils_test_algo


def test_miro2():
"""
train with MIRO
"""
args = "--te_d=2 --tr_d 0 1 --task=mnistcolor10 --debug --bs=100 --model=erm \
--trainer=miro --nname=conv_bn_pool_2 \
--layers2extract_feats _net_invar_feat.conv_net.5"
utils_test_algo(args)
17 changes: 17 additions & 0 deletions tests/test_miro3.py
@@ -0,0 +1,17 @@
"""
end-to-end test for mutual information regularization (MIRO)
"""
import pytest
from tests.utils_test import utils_test_algo


def test_miro3():
"""
train with MIRO
"""
args = "--te_d=caltech --task=mini_vlcs --debug --bs=2 --model=erm \
--trainer=miro --nname=alexnet \
--layers2extract_feats features"
with pytest.raises(RuntimeError):
utils_test_algo(args)
2 changes: 1 addition & 1 deletion tests/test_model_diva.py
@@ -32,4 +32,4 @@ def test_model_diva():
)
imgs, y_s, d_s = mk_rand_xyd(28, y_dim, 2, 2)
_, _, _, _, _ = model.infer_y_vpicn(imgs)
model(imgs, y_s, d_s)
model.cal_loss(imgs, y_s, d_s)
