diff --git a/MPNCOV.py b/MPNCOV.py
new file mode 100644
index 0000000..5e9678e
--- /dev/null
+++ b/MPNCOV.py
@@ -0,0 +1,205 @@
+
+'''
+@file: MPNCOV.py
+@author: Jiangtao Xie
+@author: Peihua Li
+Please cite the papers below if you use the code:
+Peihua Li, Jiangtao Xie, Qilong Wang and Zilin Gao. Towards Faster Training of Global Covariance Pooling Networks by Iterative Matrix Square Root Normalization. IEEE Int. Conf. on Computer Vision and Pattern Recognition (CVPR), pp. 947-955, 2018.
+Peihua Li, Jiangtao Xie, Qilong Wang and Wangmeng Zuo. Is Second-order Information Helpful for Large-scale Visual Recognition? IEEE Int. Conf. on Computer Vision (ICCV), pp. 2070-2078, 2017.
+Copyright (C) 2018 Peihua Li and Jiangtao Xie
+All rights reserved.
+'''
+import torch
+import torch.nn as nn
+from torch.autograd import Function
+
+
+class MPNCOV(nn.Module):
+    """Matrix power normalized covariance pooling (MPN-COV),
+    implementation of fast MPN-COV (i.e., iSQRT-COV):
+    https://arxiv.org/abs/1712.01034
+
+    Args:
+        iterNum: number of iterations of the Newton-Schulz method
+        is_sqrt: whether to perform the matrix square root or not
+        is_vec: whether the output is a vector or not
+        input_dim: number of channels of the input feature
+        dimension_reduction: if None, no 1x1 conv is used to reduce
+            the number of channels; if set to 256 (or another value),
+            the feature is reduced to 256 (or that many) channels.
+    """
+    def __init__(self, iterNum=3, is_sqrt=True, is_vec=True, input_dim=2048, dimension_reduction=None):
+        super(MPNCOV, self).__init__()
+        self.iterNum = iterNum
+        self.is_sqrt = is_sqrt
+        self.is_vec = is_vec
+        self.dr = dimension_reduction
+        if self.dr is not None:
+            self.conv_dr_block = nn.Sequential(
+                nn.Conv2d(input_dim, self.dr, kernel_size=1, stride=1, bias=False),
+                nn.BatchNorm2d(self.dr),
+                nn.ReLU(inplace=True)
+            )
+        output_dim = self.dr if self.dr else input_dim
+        if self.is_vec:
+            self.output_dim = int(output_dim * (output_dim + 1) / 2)
+        else:
+            self.output_dim = int(output_dim * output_dim)
+        self._init_weight()
+
+    def _init_weight(self):
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+            elif isinstance(m, nn.BatchNorm2d):
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
+
+    def _cov_pool(self, x):
+        return Covpool.apply(x)
+
+    def _sqrtm(self, x):
+        return Sqrtm.apply(x, self.iterNum)
+
+    def _triuvec(self, x):
+        return Triuvec.apply(x)
+
+    def forward(self, x):
+        if self.dr is not None:
+            x = self.conv_dr_block(x)
+        x = self._cov_pool(x)
+        if self.is_sqrt:
+            x = self._sqrtm(x)
+        if self.is_vec:
+            x = self._triuvec(x)
+        return x
+
+
+class Covpool(Function):
+    @staticmethod
+    def forward(ctx, input):
+        x = input
+        batchSize = x.data.shape[0]
+        dim = x.data.shape[1]
+        h = x.data.shape[2]
+        w = x.data.shape[3]
+        M = h * w
+        x = x.reshape(batchSize, dim, M)
+        # I_hat = (1/M)(I - (1/M)11^T), so y = x I_hat x^T is the sample covariance
+        I_hat = (-1. / M / M) * torch.ones(M, M, device=x.device) \
+                + (1. / M) * torch.eye(M, M, device=x.device)
+        I_hat = I_hat.view(1, M, M).repeat(batchSize, 1, 1).type(x.dtype)
+        y = x.bmm(I_hat).bmm(x.transpose(1, 2))
+        ctx.save_for_backward(input, I_hat)
+        return y
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        input, I_hat = ctx.saved_tensors
+        x = input
+        batchSize = x.data.shape[0]
+        dim = x.data.shape[1]
+        h = x.data.shape[2]
+        w = x.data.shape[3]
+        M = h * w
+        x = x.reshape(batchSize, dim, M)
+        grad_input = grad_output + grad_output.transpose(1, 2)
+        grad_input = grad_input.bmm(x).bmm(I_hat)
+        grad_input = grad_input.reshape(batchSize, dim, h, w)
+        return grad_input
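+
+
+# Coupled Newton-Schulz iteration used by Sqrtm below: with the
+# pre-normalized matrix A = C / tr(C), set Y_0 = A, Z_0 = I and iterate
+#     Y_{k+1} = 0.5 * Y_k * (3I - Z_k * Y_k)
+#     Z_{k+1} = 0.5 * (3I - Z_k * Y_k) * Z_k
+# so that Y_k converges to A^{1/2} and Z_k to A^{-1/2}; the result is
+# rescaled by sqrt(tr(C)) in the post-compensation step.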
+class Sqrtm(Function):
+    @staticmethod
+    def forward(ctx, input, iterN):
+        x = input
+        batchSize = x.data.shape[0]
+        dim = x.data.shape[1]
+        dtype = x.dtype
+        I3 = 3.0 * torch.eye(dim, dim, device=x.device).view(1, dim, dim).repeat(batchSize, 1, 1).type(dtype)
+        # pre-normalization: normA = tr(C), A = C / tr(C)
+        normA = (1.0 / 3.0) * x.mul(I3).sum(dim=1).sum(dim=1)
+        A = x.div(normA.view(batchSize, 1, 1).expand_as(x))
+        Y = torch.zeros(batchSize, iterN, dim, dim, requires_grad=False, device=x.device).type(dtype)
+        Z = torch.eye(dim, dim, device=x.device).view(1, dim, dim).repeat(batchSize, iterN, 1, 1).type(dtype)
+        if iterN < 2:
+            ZY = 0.5 * (I3 - A)
+            YZY = A.bmm(ZY)
+        else:
+            ZY = 0.5 * (I3 - A)
+            Y[:, 0, :, :] = A.bmm(ZY)
+            Z[:, 0, :, :] = ZY
+            for i in range(1, iterN - 1):
+                ZY = 0.5 * (I3 - Z[:, i - 1, :, :].bmm(Y[:, i - 1, :, :]))
+                Y[:, i, :, :] = Y[:, i - 1, :, :].bmm(ZY)
+                Z[:, i, :, :] = ZY.bmm(Z[:, i - 1, :, :])
+            YZY = 0.5 * Y[:, iterN - 2, :, :].bmm(I3 - Z[:, iterN - 2, :, :].bmm(Y[:, iterN - 2, :, :]))
+        # post-compensation: rescale by sqrt(tr(C))
+        y = YZY * torch.sqrt(normA).view(batchSize, 1, 1).expand_as(x)
+        ctx.save_for_backward(input, A, YZY, normA, Y, Z)
+        ctx.iterN = iterN
+        return y
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        input, A, ZY, normA, Y, Z = ctx.saved_tensors
+        iterN = ctx.iterN
+        x = input
+        batchSize = x.data.shape[0]
+        dim = x.data.shape[1]
+        dtype = x.dtype
+        der_postCom = grad_output * torch.sqrt(normA).view(batchSize, 1, 1).expand_as(x)
+        der_postComAux = (grad_output * ZY).sum(dim=1).sum(dim=1).div(2 * torch.sqrt(normA))
+        I3 = 3.0 * torch.eye(dim, dim, device=x.device).view(1, dim, dim).repeat(batchSize, 1, 1).type(dtype)
+        if iterN < 2:
+            der_NSiter = 0.5 * (der_postCom.bmm(I3 - A) - A.bmm(der_postCom))
+        else:
+            dldY = 0.5 * (der_postCom.bmm(I3 - Y[:, iterN - 2, :, :].bmm(Z[:, iterN - 2, :, :])) -
+                          Z[:, iterN - 2, :, :].bmm(Y[:, iterN - 2, :, :]).bmm(der_postCom))
+            dldZ = -0.5 * Y[:, iterN - 2, :, :].bmm(der_postCom).bmm(Y[:, iterN - 2, :, :])
+            # back-propagate through the Newton-Schulz iterations in reverse order
+            for i in range(iterN - 3, -1, -1):
+                YZ = I3 - Y[:, i, :, :].bmm(Z[:, i, :, :])
+                ZY = Z[:, i, :, :].bmm(Y[:, i, :, :])
+                dldY_ = 0.5 * (dldY.bmm(YZ) -
+                               Z[:, i, :, :].bmm(dldZ).bmm(Z[:, i, :, :]) -
+                               ZY.bmm(dldY))
+                dldZ_ = 0.5 * (YZ.bmm(dldZ) -
+                               Y[:, i, :, :].bmm(dldY).bmm(Y[:, i, :, :]) -
+                               dldZ.bmm(ZY))
+                dldY = dldY_
+                dldZ = dldZ_
+            der_NSiter = 0.5 * (dldY.bmm(I3 - A) - dldZ - A.bmm(dldY))
+        der_NSiter = der_NSiter.transpose(1, 2)
+        grad_input = der_NSiter.div(normA.view(batchSize, 1, 1).expand_as(x))
+        grad_aux = der_NSiter.mul(x).sum(dim=1).sum(dim=1)
+        for i in range(batchSize):
+            grad_input[i, :, :] += (der_postComAux[i]
+                                    - grad_aux[i] / (normA[i] * normA[i])) \
+                                   * torch.ones(dim, device=x.device).diag().type(dtype)
+        return grad_input, None
+
+
+class Triuvec(Function):
+    @staticmethod
+    def forward(ctx, input):
+        x = input
+        batchSize = x.data.shape[0]
+        dim = x.data.shape[1]
+        dtype = x.dtype
+        x = x.reshape(batchSize, dim * dim)
+        # linear indices of the upper-triangular entries (the covariance matrix is symmetric)
+        I = torch.ones(dim, dim, device=x.device).triu().reshape(dim * dim)
+        index = torch.nonzero(I).squeeze(1)
+        y = x[:, index]
+        ctx.save_for_backward(input, index)
+        return y
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        input, index = ctx.saved_tensors
+        x = input
+        batchSize = x.data.shape[0]
+        dim = x.data.shape[1]
+        dtype = x.dtype
+        grad_input = torch.zeros(batchSize, dim * dim, device=x.device, requires_grad=False).type(dtype)
+        grad_input[:, index] = grad_output
+        grad_input = grad_input.reshape(batchSize, dim, dim)
+        return grad_input
+
+
+def CovpoolLayer(var):
+    return Covpool.apply(var)
+
+
+def SqrtmLayer(var, iterN):
+    return Sqrtm.apply(var, iterN)
+
+
+def TriuvecLayer(var):
+    return Triuvec.apply(var)
\ No newline at end of file
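A minimal usage sketch (illustrative only, not part of the diff): assuming `MPNCOV.py` is importable, the module drops in after a CNN backbone in place of global average pooling. The tensor sizes and names (`pool`, `feat`) below are assumptions for the example.

```python
import torch
from MPNCOV import MPNCOV, CovpoolLayer, SqrtmLayer, TriuvecLayer

# 8 feature maps with 256 channels on a 14x14 grid -> one
# upper-triangular vector of length 256*(256+1)/2 = 32896 each.
pool = MPNCOV(iterNum=3, is_sqrt=True, is_vec=True,
              input_dim=256, dimension_reduction=None)
feat = torch.randn(8, 256, 14, 14)
out = pool(feat)
print(out.shape)  # torch.Size([8, 32896])

# The same pipeline via the functional wrappers:
cov = CovpoolLayer(feat)       # (8, 256, 256) covariance matrices
cov_sqrt = SqrtmLayer(cov, 3)  # approximate matrix square roots
vec = TriuvecLayer(cov_sqrt)   # (8, 32896) upper-triangular vectors
```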
diff --git a/NetVLAD.py b/NetVLAD.py
new file mode 100644
index 0000000..4c875b2
--- /dev/null
+++ b/NetVLAD.py
@@ -0,0 +1,61 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class NetVLAD(nn.Module):
+    """NetVLAD layer implementation"""
+
+    def __init__(self, num_clusters=64, dim=128, alpha=100.0,
+                 normalize_input=True):
+        """
+        Args:
+            num_clusters : int
+                The number of clusters
+            dim : int
+                Dimension of descriptors
+            alpha : float
+                Initialization parameter. A larger value gives a harder assignment.
+            normalize_input : bool
+                If true, descriptor-wise L2 normalization is applied to the input.
+        """
+        super(NetVLAD, self).__init__()
+        self.num_clusters = num_clusters
+        self.dim = dim
+        self.alpha = alpha
+        self.normalize_input = normalize_input
+        self.conv = nn.Conv2d(dim, num_clusters, kernel_size=(1, 1), bias=True)
+        self.centroids = nn.Parameter(torch.rand(num_clusters, dim))
+        self._init_params()
+
+    def _init_params(self):
+        self.conv.weight = nn.Parameter(
+            (2.0 * self.alpha * self.centroids).unsqueeze(-1).unsqueeze(-1)
+        )
+        self.conv.bias = nn.Parameter(
+            - self.alpha * self.centroids.norm(dim=1)
+        )
+
+    def forward(self, x):
+        N, C = x.shape[:2]
+
+        if self.normalize_input:
+            x = F.normalize(x, p=2, dim=1)  # across descriptor dim
+
+        # soft-assignment
+        soft_assign = self.conv(x).view(N, self.num_clusters, -1)
+        soft_assign = F.softmax(soft_assign, dim=1)
+
+        x_flatten = x.view(N, C, -1)
+
+        # calculate residuals to each cluster
+        residual = x_flatten.expand(self.num_clusters, -1, -1, -1).permute(1, 0, 2, 3) - \
+            self.centroids.expand(x_flatten.size(-1), -1, -1).permute(1, 2, 0).unsqueeze(0)
+        residual *= soft_assign.unsqueeze(2)
+        vlad = residual.sum(dim=-1)
+
+        vlad = F.normalize(vlad, p=2, dim=2)  # intra-normalization
+        vlad = vlad.view(x.size(0), -1)       # flatten
+        vlad = F.normalize(vlad, p=2, dim=1)  # L2 normalize
+
+        return vlad
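A matching sketch for the NetVLAD layer (again illustrative, not part of the diff): it expects a 4-D conv feature map and returns one L2-normalized VLAD vector of length `num_clusters * dim` per image. The batch size and grid size are assumptions for the example.

```python
import torch
from NetVLAD import NetVLAD

# 4 feature maps with 128-D descriptors on a 7x7 grid ->
# VLAD vectors of length 64 * 128 = 8192.
vlad = NetVLAD(num_clusters=64, dim=128, alpha=100.0)
feat = torch.randn(4, 128, 7, 7)
out = vlad(feat)
print(out.shape)  # torch.Size([4, 8192])
```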