Commit 49efc25: Add files via upload

zwx8981 authored Mar 29, 2019
1 parent 5324ef5 commit 49efc25
Showing 3 changed files with 164 additions and 0 deletions.
43 changes: 43 additions & 0 deletions BCNN.py
@@ -0,0 +1,43 @@
'''
@file: BCNN.py
@author: Jiangtao Xie
@author: Peihua Li
'''
import torch
import torch.nn as nn

class BCNN(nn.Module):
"""Bilinear Pool
implementation of Bilinear CNN (BCNN)
https://arxiv.org/abs/1504.07889v5
Args:
thresh: small positive number for computation stability
is_vec: whether the output is a vector or not
input_dim: the #channel of input feature
"""
def __init__(self, thresh=1e-8, is_vec=True, input_dim=2048):
super(BCNN, self).__init__()
self.thresh = thresh
self.is_vec = is_vec
self.output_dim = input_dim * input_dim
    def _bilinearpool(self, x):
        # flatten spatial positions and average the outer products of the
        # per-position channel descriptors: (1/hw) * X X^T, shape (batch, dim, dim)
        batchSize, dim, h, w = x.data.shape
        x = x.reshape(batchSize, dim, h * w)
        x = 1. / (h * w) * x.bmm(x.transpose(1, 2))
        return x

    def _signed_sqrt(self, x):
        # element-wise signed square root; thresh keeps the gradient finite near zero
        x = torch.mul(x.sign(), torch.sqrt(x.abs() + self.thresh))
        return x

def _l2norm(self, x):
x = nn.functional.normalize(x)
return x

def forward(self, x):
x = self._bilinearpool(x)
x = self._signed_sqrt(x)
if self.is_vec:
            x = x.view(x.size(0), -1)
x = self._l2norm(x)
return x
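
For context, a minimal usage sketch, not part of the commit and assuming the file is importable as BCNN: with a VGG-style 512-channel feature map, BCNN produces a signed-sqrt, l2-normalized bilinear vector of length 512 * 512.

import torch
from BCNN import BCNN

features = torch.randn(4, 512, 14, 14)   # (batch, channels, h, w) backbone output
pool = BCNN(is_vec=True, input_dim=512)
vec = pool(features)                     # bilinear pool -> signed sqrt -> l2 norm
print(vec.shape)                         # torch.Size([4, 262144]) = 512 * 512
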
79 changes: 79 additions & 0 deletions CBP.py
@@ -0,0 +1,79 @@
'''
@file: CBP.py
@author: Chunqiao Xu
@author: Jiangtao Xie
@author: Peihua Li
'''
import torch
import torch.nn as nn
class CBP(nn.Module):
"""Compact Bilinear Pooling
implementation of Compact Bilinear Pooling (CBP)
https://arxiv.org/pdf/1511.06062.pdf
Args:
thresh: small positive number for computation stability
projDim: projected dimension
input_dim: the #channel of input feature
"""
def __init__(self, thresh=1e-8, projDim=8192, input_dim=512):
super(CBP, self).__init__()
self.thresh = thresh
self.projDim = projDim
self.input_dim = input_dim
self.output_dim = projDim
        torch.manual_seed(1)  # fix the random hash functions so projections are reproducible
        self.h_ = [  # hash indices: map each input channel to a random output bin
            torch.randint(0, self.output_dim, (self.input_dim,), dtype=torch.long),
            torch.randint(0, self.output_dim, (self.input_dim,), dtype=torch.long)
        ]
        self.weights_ = [  # random signs in {-1, +1} for the count sketch
            (2 * torch.randint(0, 2, (self.input_dim,)) - 1).float(),
            (2 * torch.randint(0, 2, (self.input_dim,)) - 1).float()
        ]

indices1 = torch.cat((torch.arange(input_dim, dtype=torch.long).reshape(1, -1),
self.h_[0].reshape(1, -1)), dim=0)
indices2 = torch.cat((torch.arange(input_dim, dtype=torch.long).reshape(1, -1),
self.h_[1].reshape(1, -1)), dim=0)

        self.sparseM = [  # dense count-sketch projection matrices, (input_dim, output_dim)
            torch.sparse.FloatTensor(indices1, self.weights_[0], torch.Size([self.input_dim, self.output_dim])).to_dense(),
            torch.sparse.FloatTensor(indices2, self.weights_[1], torch.Size([self.input_dim, self.output_dim])).to_dense(),
        ]
def _signed_sqrt(self, x):
x = torch.mul(x.sign(), torch.sqrt(x.abs()+self.thresh))
return x

def _l2norm(self, x):
x = nn.functional.normalize(x)
return x

    def forward(self, x):
        bsn = 1  # chunk size: number of images sketched per loop iteration
        batchSize, dim, h, w = x.data.shape
        x_flat = x.permute(0, 2, 3, 1).contiguous().view(-1, dim)  # (batch*h*w, dim)
        y = torch.ones(batchSize, self.output_dim, device=x.device)

        for img in range(batchSize // bsn):
            segLen = bsn * h * w
            upper = batchSize * h * w
            interLarge = torch.arange(img * segLen, min(upper, (img + 1) * segLen), dtype=torch.long)  # flattened spatial positions of this chunk
            interSmall = torch.arange(img * bsn, min(upper, (img + 1) * bsn), dtype=torch.long)  # image indices of this chunk
            batch_x = x_flat[interLarge, :]

            sketch1 = batch_x.mm(self.sparseM[0].to(x.device)).unsqueeze(2)
            # torch.fft/torch.ifft below are the pre-1.8 PyTorch API, with the real
            # and imaginary parts stacked along the last dimension
            sketch1 = torch.fft(torch.cat((sketch1, torch.zeros(sketch1.size(), device=x.device)), dim=2), 1)

            sketch2 = batch_x.mm(self.sparseM[1].to(x.device)).unsqueeze(2)
            sketch2 = torch.fft(torch.cat((sketch2, torch.zeros(sketch2.size(), device=x.device)), dim=2), 1)

            # element-wise complex multiplication of the two sketches in the frequency domain
            Re = sketch1[:, :, 0].mul(sketch2[:, :, 0]) - sketch1[:, :, 1].mul(sketch2[:, :, 1])
            Im = sketch1[:, :, 0].mul(sketch2[:, :, 1]) + sketch1[:, :, 1].mul(sketch2[:, :, 0])

            tmp_y = torch.ifft(torch.cat((Re.unsqueeze(2), Im.unsqueeze(2)), dim=2), 1)[:, :, 0]

y[interSmall, :] = tmp_y.view(torch.numel(interSmall), h, w, self.output_dim).sum(dim=1).sum(dim=1)

y = self._signed_sqrt(y)
y = self._l2norm(y)
return y
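
For context, a minimal usage sketch, not part of the commit and assuming the file is importable as CBP; note that the forward pass uses the pre-1.8 torch.fft/torch.ifft function API, so it requires an older PyTorch. CBP approximates the 512 * 512-dim bilinear vector with a projDim-dimensional count sketch.

import torch
from CBP import CBP

features = torch.randn(4, 512, 14, 14)   # (batch, channels, h, w) backbone output
pool = CBP(projDim=8192, input_dim=512)
vec = pool(features)                      # sketched bilinear features
print(vec.shape)                          # torch.Size([4, 8192])
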
42 changes: 42 additions & 0 deletions TRILINEAR.py
@@ -0,0 +1,42 @@
'''
@file: TRILINEAR.py
@author: Jiangtao Xie
@author: Peihua Li
'''
import torch
import torch.nn as nn
import torch.nn.functional as F

class TRILINEAR(nn.Module):
    """Trilinear Pooling
    implementation of trilinear attention pooling
    Args:
        input_dim: the #channel of input feature
    """
    def __init__(self, input_dim=2048):
        super(TRILINEAR, self).__init__()
        self.output_dim = input_dim

    def _trilinearpool(self, x):
        batchSize, dim, h, w = x.data.shape
        x = x.reshape(batchSize, dim, h * w)
        x_norm = F.softmax(x, dim=2)  # normalize each channel map over spatial positions
        channel_relation = x_norm.bmm(x.transpose(1, 2))  # inter-channel relationship map: b*c*c
        x = channel_relation.bmm(x)  # trilinear attention map: b*c*(h*w)
        x = x.mean(2)  # spatial average -> b*c
        return x

    def forward(self, x):
        x = self._trilinearpool(x)
        return x
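
For context, a minimal usage sketch, not part of the commit and assuming the file is importable as TRILINEAR: trilinear pooling keeps the channel dimension, returning one attention-weighted activation per channel.

import torch
from TRILINEAR import TRILINEAR

features = torch.randn(4, 2048, 7, 7)    # e.g. a ResNet-50 final feature map
pool = TRILINEAR(input_dim=2048)
vec = pool(features)                      # attention-pooled channel descriptor
print(vec.shape)                          # torch.Size([4, 2048])
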
