Commit
Showing 2 changed files with 266 additions and 0 deletions.
MPNCOV.py (205 additions)
'''
@file: MPNCOV.py
@author: Jiangtao Xie
@author: Peihua Li
Please cite the papers below if you use the code:
Peihua Li, Jiangtao Xie, Qilong Wang and Zilin Gao. Towards Faster Training of Global Covariance Pooling Networks by Iterative Matrix Square Root Normalization. IEEE Int. Conf. on Computer Vision and Pattern Recognition (CVPR), pp. 947-955, 2018.
Peihua Li, Jiangtao Xie, Qilong Wang and Wangmeng Zuo. Is Second-order Information Helpful for Large-scale Visual Recognition? IEEE Int. Conf. on Computer Vision (ICCV), pp. 2070-2078, 2017.
Copyright (C) 2018 Peihua Li and Jiangtao Xie
All rights reserved.
'''
import torch
import torch.nn as nn
from torch.autograd import Function

class MPNCOV(nn.Module):
    """Matrix power normalized covariance pooling (MPNCOV).

    Implementation of fast MPN-COV (i.e., iSQRT-COV):
    https://arxiv.org/abs/1712.01034

    Args:
        iterNum: number of Newton-Schulz iterations
        is_sqrt: whether to apply matrix square root normalization
        is_vec: whether to vectorize the output matrix
        input_dim: number of channels of the input feature map
        dimension_reduction: if None, no 1x1 convolution is used to reduce the
            number of channels; if set to 256 (or another value), the channels
            are first reduced to that number.
    """
    def __init__(self, iterNum=3, is_sqrt=True, is_vec=True, input_dim=2048, dimension_reduction=None):
        super(MPNCOV, self).__init__()
        self.iterNum = iterNum
        self.is_sqrt = is_sqrt
        self.is_vec = is_vec
        self.dr = dimension_reduction
        if self.dr is not None:
            # optional 1x1 conv + BN + ReLU block to reduce the channel dimension
            self.conv_dr_block = nn.Sequential(
                nn.Conv2d(input_dim, self.dr, kernel_size=1, stride=1, bias=False),
                nn.BatchNorm2d(self.dr),
                nn.ReLU(inplace=True)
            )
        output_dim = self.dr if self.dr else input_dim
        if self.is_vec:
            # upper-triangular vectorization of a symmetric d x d matrix: d*(d+1)/2 entries
            self.output_dim = int(output_dim * (output_dim + 1) / 2)
        else:
            self.output_dim = int(output_dim * output_dim)
        self._init_weight()

    def _init_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _cov_pool(self, x):
        return Covpool.apply(x)

    def _sqrtm(self, x):
        return Sqrtm.apply(x, self.iterNum)

    def _triuvec(self, x):
        return Triuvec.apply(x)

    def forward(self, x):
        if self.dr is not None:
            x = self.conv_dr_block(x)
        x = self._cov_pool(x)        # (N, d, h, w) -> (N, d, d) covariance matrices
        if self.is_sqrt:
            x = self._sqrtm(x)       # approximate matrix square root (Newton-Schulz)
        if self.is_vec:
            x = self._triuvec(x)     # vectorize the upper-triangular part
        return x

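A minimal usage sketch (not part of the commit): MPNCOV is typically placed after the last convolutional stage of a backbone in place of global average pooling, and its vectorized output feeds a linear classifier. The feature-map shape (2048 channels, 7x7) and the 1000-class head below are illustrative assumptions.

# Illustrative usage sketch; backbone shape and classifier are assumptions.
import torch
import torch.nn as nn

pool = MPNCOV(iterNum=5, is_sqrt=True, is_vec=True,
              input_dim=2048, dimension_reduction=256)
classifier = nn.Linear(pool.output_dim, 1000)    # 256*257/2 = 32896 features -> 1000 classes

feat = torch.randn(4, 2048, 7, 7)                # e.g. the last feature map of a ResNet-50
logits = classifier(pool(feat).view(4, -1))      # logits: (4, 1000)
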
class Covpool(Function):
    @staticmethod
    def forward(ctx, input):
        x = input
        batchSize = x.data.shape[0]
        dim = x.data.shape[1]
        h = x.data.shape[2]
        w = x.data.shape[3]
        M = h * w
        x = x.reshape(batchSize, dim, M)
        # I_hat = (1/M) * (I - (1/M) * 1 1^T) is the centering matrix, so that
        # x * I_hat * x^T is the covariance of the M spatial descriptors
        I_hat = (-1. / M / M) * torch.ones(M, M, device=x.device) + (1. / M) * torch.eye(M, M, device=x.device)
        I_hat = I_hat.view(1, M, M).repeat(batchSize, 1, 1).type(x.dtype)
        y = x.bmm(I_hat).bmm(x.transpose(1, 2))
        ctx.save_for_backward(input, I_hat)
        return y

    @staticmethod
    def backward(ctx, grad_output):
        input, I_hat = ctx.saved_tensors
        x = input
        batchSize = x.data.shape[0]
        dim = x.data.shape[1]
        h = x.data.shape[2]
        w = x.data.shape[3]
        M = h * w
        x = x.reshape(batchSize, dim, M)
        # gradient of x * I_hat * x^T: symmetrize the incoming gradient, then map back to (N, d, h, w)
        grad_input = grad_output + grad_output.transpose(1, 2)
        grad_input = grad_input.bmm(x).bmm(I_hat)
        grad_input = grad_input.reshape(batchSize, dim, h, w)
        return grad_input

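Covpool computes, for each sample, the d x d covariance of the M = h*w spatial descriptors. A quick numerical check that it matches the biased sample covariance computed directly (illustrative sketch, not part of the commit):

# Sanity check (illustrative): Covpool vs. the biased covariance of the spatial descriptors.
import torch

x = torch.randn(2, 8, 5, 5, dtype=torch.double)
cov = Covpool.apply(x)                         # (2, 8, 8)

xf = x.reshape(2, 8, -1)                       # (2, 8, M) with M = 25
xc = xf - xf.mean(dim=2, keepdim=True)         # center each channel over space
ref = xc.bmm(xc.transpose(1, 2)) / xf.shape[2]
print(torch.allclose(cov, ref))                # expected: True
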
class Sqrtm(Function):
    @staticmethod
    def forward(ctx, input, iterN):
        x = input
        batchSize = x.data.shape[0]
        dim = x.data.shape[1]
        dtype = x.dtype
        I3 = 3.0 * torch.eye(dim, dim, device=x.device).view(1, dim, dim).repeat(batchSize, 1, 1).type(dtype)
        # pre-normalization: divide each matrix by its trace so the Newton-Schulz
        # iteration converges; (1/3) * sum(x * 3I) equals trace(x)
        normA = (1.0 / 3.0) * x.mul(I3).sum(dim=1).sum(dim=1)
        A = x.div(normA.view(batchSize, 1, 1).expand_as(x))
        Y = torch.zeros(batchSize, iterN, dim, dim, requires_grad=False, device=x.device).type(dtype)
        Z = torch.eye(dim, dim, device=x.device).view(1, dim, dim).repeat(batchSize, iterN, 1, 1).type(dtype)
        if iterN < 2:
            ZY = 0.5 * (I3 - A)
            YZY = A.bmm(ZY)
        else:
            # coupled Newton-Schulz iteration: Y_k -> A^(1/2), Z_k -> A^(-1/2)
            ZY = 0.5 * (I3 - A)
            Y[:, 0, :, :] = A.bmm(ZY)
            Z[:, 0, :, :] = ZY
            for i in range(1, iterN - 1):
                ZY = 0.5 * (I3 - Z[:, i - 1, :, :].bmm(Y[:, i - 1, :, :]))
                Y[:, i, :, :] = Y[:, i - 1, :, :].bmm(ZY)
                Z[:, i, :, :] = ZY.bmm(Z[:, i - 1, :, :])
            YZY = 0.5 * Y[:, iterN - 2, :, :].bmm(I3 - Z[:, iterN - 2, :, :].bmm(Y[:, iterN - 2, :, :]))
        # post-compensation: multiply by sqrt(trace) to undo the pre-normalization
        y = YZY * torch.sqrt(normA).view(batchSize, 1, 1).expand_as(x)
        ctx.save_for_backward(input, A, YZY, normA, Y, Z)
        ctx.iterN = iterN
        return y

    @staticmethod
    def backward(ctx, grad_output):
        input, A, ZY, normA, Y, Z = ctx.saved_tensors
        iterN = ctx.iterN
        x = input
        batchSize = x.data.shape[0]
        dim = x.data.shape[1]
        dtype = x.dtype
        der_postCom = grad_output * torch.sqrt(normA).view(batchSize, 1, 1).expand_as(x)
        der_postComAux = (grad_output * ZY).sum(dim=1).sum(dim=1).div(2 * torch.sqrt(normA))
        I3 = 3.0 * torch.eye(dim, dim, device=x.device).view(1, dim, dim).repeat(batchSize, 1, 1).type(dtype)
        if iterN < 2:
            der_NSiter = 0.5 * (der_postCom.bmm(I3 - A) - A.bmm(der_postCom))
        else:
            # back-propagate through the Newton-Schulz iterations in reverse order
            dldY = 0.5 * (der_postCom.bmm(I3 - Y[:, iterN - 2, :, :].bmm(Z[:, iterN - 2, :, :])) -
                          Z[:, iterN - 2, :, :].bmm(Y[:, iterN - 2, :, :]).bmm(der_postCom))
            dldZ = -0.5 * Y[:, iterN - 2, :, :].bmm(der_postCom).bmm(Y[:, iterN - 2, :, :])
            for i in range(iterN - 3, -1, -1):
                YZ = I3 - Y[:, i, :, :].bmm(Z[:, i, :, :])
                ZY = Z[:, i, :, :].bmm(Y[:, i, :, :])
                dldY_ = 0.5 * (dldY.bmm(YZ) -
                               Z[:, i, :, :].bmm(dldZ).bmm(Z[:, i, :, :]) -
                               ZY.bmm(dldY))
                dldZ_ = 0.5 * (YZ.bmm(dldZ) -
                               Y[:, i, :, :].bmm(dldY).bmm(Y[:, i, :, :]) -
                               dldZ.bmm(ZY))
                dldY = dldY_
                dldZ = dldZ_
            der_NSiter = 0.5 * (dldY.bmm(I3 - A) - dldZ - A.bmm(dldY))
        der_NSiter = der_NSiter.transpose(1, 2)
        # gradient through the trace pre-normalization and post-compensation
        grad_input = der_NSiter.div(normA.view(batchSize, 1, 1).expand_as(x))
        grad_aux = der_NSiter.mul(x).sum(dim=1).sum(dim=1)
        for i in range(batchSize):
            grad_input[i, :, :] += (der_postComAux[i]
                                    - grad_aux[i] / (normA[i] * normA[i])) \
                                   * torch.ones(dim, device=x.device).diag().type(dtype)
        return grad_input, None

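Sqrtm implements the coupled Newton-Schulz iteration: with A = Sigma / tr(Sigma), Y0 = A and Z0 = I, it repeats Y_{k+1} = 0.5 * Y_k (3I - Z_k Y_k) and Z_{k+1} = 0.5 * (3I - Z_k Y_k) Z_k, so Y_k approaches A^(1/2) and Z_k approaches A^(-1/2); the trace pre-normalization keeps ||I - A|| below 1 so the iteration converges, and the final multiplication by sqrt(tr(Sigma)) undoes the scaling. A quick comparison against an eigendecomposition-based square root (illustrative sketch, not part of the commit):

# Sanity check (illustrative): Newton-Schulz approximation vs. an exact matrix square root.
import torch

x = torch.randn(2, 16, 7, 7, dtype=torch.double)
cov = Covpool.apply(x)                         # SPD matrices (M = 49 spatial positions, 16 channels)
approx = Sqrtm.apply(cov, 5)                   # 5 Newton-Schulz iterations

e, v = torch.linalg.eigh(cov)                  # reference square root via eigendecomposition
ref = v @ torch.diag_embed(e.clamp(min=0).sqrt()) @ v.transpose(1, 2)
print((approx - ref).abs().max())              # small; shrinks further as iterN grows
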
class Triuvec(Function):
    @staticmethod
    def forward(ctx, input):
        x = input
        batchSize = x.data.shape[0]
        dim = x.data.shape[1]
        dtype = x.dtype
        x = x.reshape(batchSize, dim * dim)
        # indices of the upper-triangular entries (including the diagonal): dim*(dim+1)/2 of them
        I = torch.ones(dim, dim).triu().reshape(dim * dim)
        index = I.nonzero()
        y = torch.zeros(batchSize, int(dim * (dim + 1) / 2), device=x.device).type(dtype)
        y = x[:, index]
        ctx.save_for_backward(input, index)
        return y

    @staticmethod
    def backward(ctx, grad_output):
        input, index = ctx.saved_tensors
        x = input
        batchSize = x.data.shape[0]
        dim = x.data.shape[1]
        dtype = x.dtype
        # scatter the gradient back into the full dim x dim matrix
        grad_input = torch.zeros(batchSize, dim * dim, device=x.device, requires_grad=False).type(dtype)
        grad_input[:, index] = grad_output
        grad_input = grad_input.reshape(batchSize, dim, dim)
        return grad_input

def CovpoolLayer(var):
    return Covpool.apply(var)


def SqrtmLayer(var, iterN):
    return Sqrtm.apply(var, iterN)


def TriuvecLayer(var):
    return Triuvec.apply(var)
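The three functional wrappers can also be chained directly, without the MPNCOV module. A short illustrative sketch (the feature-map shape is an assumption, not from the commit):

# End-to-end sketch: covariance pooling -> Newton-Schulz square root -> upper-triangular vectorization.
import torch

feat = torch.randn(4, 64, 14, 14)              # any (N, d, h, w) feature map
y = CovpoolLayer(feat)                         # (4, 64, 64) covariance matrices
y = SqrtmLayer(y, 5)                           # matrix square root, 5 iterations
y = TriuvecLayer(y)                            # upper triangle: 64*65/2 = 2080 values per sample
y = y.view(4, -1)                              # (4, 2080), ready for a linear classifier
print(y.shape)
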
Second file: NetVLAD layer (61 additions)
import torch
import torch.nn as nn
import torch.nn.functional as F


class NetVLAD(nn.Module):
    """NetVLAD layer implementation"""

    def __init__(self, num_clusters=64, dim=128, alpha=100.0,
                 normalize_input=True):
        """
        Args:
            num_clusters : int
                The number of clusters
            dim : int
                Dimension of descriptors
            alpha : float
                Initialization parameter; a larger value gives a harder assignment.
            normalize_input : bool
                If True, descriptor-wise L2 normalization is applied to the input.
        """
        super(NetVLAD, self).__init__()
        self.num_clusters = num_clusters
        self.dim = dim
        self.alpha = alpha
        self.normalize_input = normalize_input
        # 1x1 conv producing the soft-assignment logits w_k^T x_i + b_k
        self.conv = nn.Conv2d(dim, num_clusters, kernel_size=(1, 1), bias=True)
        self.centroids = nn.Parameter(torch.rand(num_clusters, dim))
        self._init_params()

    def _init_params(self):
        # initialize the soft-assignment conv from the centroids: w_k = 2 * alpha * c_k
        self.conv.weight = nn.Parameter(
            (2.0 * self.alpha * self.centroids).unsqueeze(-1).unsqueeze(-1)
        )
        # b_k = -alpha * ||c_k||^2, so that at initialization the logits match the
        # soft assignment exp(-alpha * ||x_i - c_k||^2) up to a per-descriptor constant
        self.conv.bias = nn.Parameter(
            -self.alpha * self.centroids.norm(dim=1).pow(2)
        )

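With this initialization, the 1x1 convolution reproduces the soft assignment proportional to exp(-alpha * ||x_i - c_k||^2): the logit w_k^T x_i + b_k equals -alpha * ||x_i - c_k||^2 plus a term alpha * ||x_i||^2 that is constant over clusters and cancels in the softmax. A quick check of that identity, assuming the initialization above (illustrative sketch, not part of the commit):

# Sanity check (illustrative): conv-based soft assignment vs. distance-based soft assignment.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
net = NetVLAD(num_clusters=4, dim=8, alpha=1.0).double()
x = F.normalize(torch.randn(1, 8, 3, 3, dtype=torch.double), p=2, dim=1)

logits = net.conv(x).view(1, 4, -1)                       # w_k^T x_i + b_k, shape (1, K, M)
assign = F.softmax(logits, dim=1)

d2 = torch.cdist(x.view(1, 8, -1).transpose(1, 2),        # squared distances, shape (1, M, K)
                 net.centroids.unsqueeze(0)).pow(2)
assign_ref = F.softmax(-net.alpha * d2, dim=2).transpose(1, 2)
print(torch.allclose(assign, assign_ref, atol=1e-6))      # expected: True
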
    def forward(self, x):
        N, C = x.shape[:2]

        if self.normalize_input:
            x = F.normalize(x, p=2, dim=1)  # across descriptor dim

        # soft-assignment: (N, num_clusters, H*W)
        soft_assign = self.conv(x).view(N, self.num_clusters, -1)
        soft_assign = F.softmax(soft_assign, dim=1)

        x_flatten = x.view(N, C, -1)

        # calculate residuals to each cluster: (N, num_clusters, C, H*W)
        residual = x_flatten.expand(self.num_clusters, -1, -1, -1).permute(1, 0, 2, 3) - \
            self.centroids.expand(x_flatten.size(-1), -1, -1).permute(1, 2, 0).unsqueeze(0)
        residual *= soft_assign.unsqueeze(2)   # weight residuals by the soft assignment
        vlad = residual.sum(dim=-1)            # aggregate over spatial locations

        vlad = F.normalize(vlad, p=2, dim=2)   # intra-normalization
        vlad = vlad.view(x.size(0), -1)        # flatten to (N, num_clusters * C)
        vlad = F.normalize(vlad, p=2, dim=1)   # L2 normalize

        return vlad
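A minimal usage sketch (not part of the commit): the layer takes an (N, dim, H, W) feature map, e.g. from a CNN backbone, and returns an L2-normalized (N, num_clusters * dim) global descriptor. The shapes below are illustrative assumptions.

# Illustrative usage sketch; descriptor dimension and map size are assumptions.
import torch

vlad = NetVLAD(num_clusters=64, dim=512)       # e.g. on top of a 512-channel conv feature map
feat = torch.randn(2, 512, 14, 14)             # (N, dim, H, W) local descriptors
desc = vlad(feat)                              # (2, 64*512) = (2, 32768)
print(desc.shape, desc.norm(dim=1))            # each row has unit L2 norm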