From 49efc256ee5dad6c4c0fbc68134c7d4056ebfb02 Mon Sep 17 00:00:00 2001 From: Onionbao <411965697@qq.com> Date: Fri, 29 Mar 2019 14:58:12 +0800 Subject: [PATCH] Add files via upload --- BCNN.py | 43 ++++++++++++++++++++++++++++ CBP.py | 79 ++++++++++++++++++++++++++++++++++++++++++++++++++++ TRILINEAR.py | 42 ++++++++++++++++++++++++++++ 3 files changed, 164 insertions(+) create mode 100644 BCNN.py create mode 100644 CBP.py create mode 100644 TRILINEAR.py diff --git a/BCNN.py b/BCNN.py new file mode 100644 index 0000000..9b2273f --- /dev/null +++ b/BCNN.py @@ -0,0 +1,43 @@ +''' +@file: BCNN.py +@author: Jiangtao Xie +@author: Peihua Li +''' +import torch +import torch.nn as nn + +class BCNN(nn.Module): + """Bilinear Pool + implementation of Bilinear CNN (BCNN) + https://arxiv.org/abs/1504.07889v5 + Args: + thresh: small positive number for computation stability + is_vec: whether the output is a vector or not + input_dim: the #channel of input feature + """ + def __init__(self, thresh=1e-8, is_vec=True, input_dim=2048): + super(BCNN, self).__init__() + self.thresh = thresh + self.is_vec = is_vec + self.output_dim = input_dim * input_dim + def _bilinearpool(self, x): + batchSize, dim, h, w = x.data.shape + x = x.reshape(batchSize, dim, h * w) + x = 1. / (h * w) * x.bmm(x.transpose(1, 2)) + return x + + def _signed_sqrt(self, x): + x = torch.mul(x.sign(), torch.sqrt(x.abs()+self.thresh)) + return x + + def _l2norm(self, x): + x = nn.functional.normalize(x) + return x + + def forward(self, x): + x = self._bilinearpool(x) + x = self._signed_sqrt(x) + if self.is_vec: + x = x.view(x.size(0),-1) + x = self._l2norm(x) + return x \ No newline at end of file diff --git a/CBP.py b/CBP.py new file mode 100644 index 0000000..014e5b5 --- /dev/null +++ b/CBP.py @@ -0,0 +1,79 @@ +''' +@file: CBP.py +@author: Chunqiao Xu +@author: Jiangtao Xie +@author: Peihua Li +''' +import torch +import torch.nn as nn +class CBP(nn.Module): + """Compact Bilinear Pooling + implementation of Compact Bilinear Pooling (CBP) + https://arxiv.org/pdf/1511.06062.pdf + Args: + thresh: small positive number for computation stability + projDim: projected dimension + input_dim: the #channel of input feature + """ + def __init__(self, thresh=1e-8, projDim=8192, input_dim=512): + super(CBP, self).__init__() + self.thresh = thresh + self.projDim = projDim + self.input_dim = input_dim + self.output_dim = projDim + torch.manual_seed(1) + self.h_ = [ + torch.randint(0, self.output_dim, (self.input_dim,),dtype=torch.long), + torch.randint(0, self.output_dim, (self.input_dim,),dtype=torch.long) + ] + self.weights_ = [ + (2 * torch.randint(0, 2, (self.input_dim,)) - 1).float(), + (2 * torch.randint(0, 2, (self.input_dim,)) - 1).float() + ] + + indices1 = torch.cat((torch.arange(input_dim, dtype=torch.long).reshape(1, -1), + self.h_[0].reshape(1, -1)), dim=0) + indices2 = torch.cat((torch.arange(input_dim, dtype=torch.long).reshape(1, -1), + self.h_[1].reshape(1, -1)), dim=0) + + self.sparseM = [ + torch.sparse.FloatTensor(indices1, self.weights_[0], torch.Size([self.input_dim, self.output_dim])).to_dense(), + torch.sparse.FloatTensor(indices2, self.weights_[1], torch.Size([self.input_dim, self.output_dim])).to_dense(), + ] + def _signed_sqrt(self, x): + x = torch.mul(x.sign(), torch.sqrt(x.abs()+self.thresh)) + return x + + def _l2norm(self, x): + x = nn.functional.normalize(x) + return x + + def forward(self, x): + bsn = 1 + batchSize, dim, h, w = x.data.shape + x_flat = x.permute(0, 2, 3, 1).contiguous().view(-1, dim) # batchsize,h, w, dim, + y = torch.ones(batchSize, self.output_dim, device=x.device) + + for img in range(batchSize // bsn): + segLen = bsn * h * w + upper = batchSize * h * w + interLarge = torch.arange(img * segLen, min(upper, (img + 1) * segLen), dtype=torch.long) + interSmall = torch.arange(img * bsn, min(upper, (img + 1) * bsn), dtype=torch.long) + batch_x = x_flat[interLarge, :] + + sketch1 = batch_x.mm(self.sparseM[0].to(x.device)).unsqueeze(2) + sketch1 = torch.fft(torch.cat((sketch1, torch.zeros(sketch1.size(), device=x.device)), dim=2), 1) + + sketch2 = batch_x.mm(self.sparseM[1].to(x.device)).unsqueeze(2) + sketch2 = torch.fft(torch.cat((sketch2, torch.zeros(sketch2.size(), device=x.device)), dim=2), 1) + + Re = sketch1[:, :, 0].mul(sketch2[:, :, 0]) - sketch1[:, :, 1].mul(sketch2[:, :, 1]) + Im = sketch1[:, :, 0].mul(sketch2[:, :, 1]) + sketch1[:, :, 1].mul(sketch2[:, :, 0]) + + tmp_y = torch.ifft(torch.cat((Re.unsqueeze(2), Im.unsqueeze(2)), dim=2), 1)[:, :, 0] + + y[interSmall, :] = tmp_y.view(torch.numel(interSmall), h, w, self.output_dim).sum(dim=1).sum(dim=1) + + y = self._signed_sqrt(y) + y = self._l2norm(y) + return y \ No newline at end of file diff --git a/TRILINEAR.py b/TRILINEAR.py new file mode 100644 index 0000000..24a7a6f --- /dev/null +++ b/TRILINEAR.py @@ -0,0 +1,42 @@ +''' +@file: BCNN.py +@author: Jiangtao Xie +@author: Peihua Li +''' +import torch +import torch.nn as nn +import torch.nn.functional as F + +class TRILINEAR(nn.Module): + + def __init__(self, input_dim=2048): + super(TRILINEAR, self).__init__() + #self.thresh = thresh + #self.is_vec = is_vec + self.output_dim = input_dim + def _trilinearpool(self, x): + batchSize, dim, h, w = x.data.shape + x = x.reshape(batchSize, dim, h * w) + #x = 1. / (h * w) * x.bmm(x.transpose(1, 2)) + x_norm = F.softmax(dim=2) + channel_relation = x_norm.bmm(x.transpose(1, 2)) #inter-channel relationship map + #channel_relation = F.softmax(channel_relation) + x = channel_relation.bmm(x) #trilinear attention map: b*c*(h*w) + x = x.mean(2) + return x + + #def _signed_sqrt(self, x): + # x = torch.mul(x.sign(), torch.sqrt(x.abs()+self.thresh)) + # return x + + #def _l2norm(self, x): + # x = F.normalize(x) + # return x + + def forward(self, x): + x = self._trilinearpool(x) + #x = self._signed_sqrt(x) + #if self.is_vec: + # x = x.view(x.size(0),-1) + #x = self._l2norm(x) + return x \ No newline at end of file