-
Notifications
You must be signed in to change notification settings - Fork 32
/
TRILINEAR.py
36 lines (31 loc) · 1.12 KB
/
TRILINEAR.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import torch
import torch.nn as nn
import torch.nn.functional as F
class TRILINEAR(nn.Module):
    """Trilinear attention pooling over a feature map.

    The feature map X is flattened to ``(batch, channels, positions)``. The
    module then computes:

        relation = softmax(X, over positions) @ X^T   # (batch, C, C)
        out      = relation @ X                       # (batch, C, positions)

    ``relation`` is the inter-channel relationship map. ``out`` is the
    trilinear attention map.

    Args:
        is_vec: If true, ``forward`` averages the attention map over the
            spatial positions and returns a ``(batch, C)`` tensor. Otherwise
            it returns the full ``(batch, C, positions)`` map.
        input_dim: Stored as ``self.output_dim``. It is not used to validate
            the input. Presumably it is the expected channel count; confirm
            with callers.
    """

    def __init__(self, is_vec, input_dim=2048):
        super().__init__()
        self.is_vec = is_vec
        # Pooling keeps the channel count, so output_dim == input_dim.
        self.output_dim = input_dim

    def _trilinearpool(self, x):
        """Return the trilinear attention map of ``x``.

        Args:
            x: Tensor of shape ``(batch, C, *spatial)``, e.g.
                ``(batch, C, H, W)``. All trailing dims are flattened into a
                single positions axis.

        Returns:
            Tensor of shape ``(batch, C, prod(spatial))``.
        """
        # (B, C, H, W) -> (B, C, H*W). This is a no-op for already-flat input.
        x = x.flatten(start_dim=2)
        # Normalise each channel's responses across spatial positions.
        x_norm = F.softmax(x, dim=2)
        # Inter-channel relationship map: (B, C, N) @ (B, N, C) -> (B, C, C).
        channel_relation = x_norm.bmm(x.transpose(1, 2))
        # Trilinear attention map: (B, C, C) @ (B, C, N) -> (B, C, N).
        return channel_relation.bmm(x)

    def forward(self, x):
        """Apply trilinear attention pooling.

        Args:
            x: Tensor of shape ``(batch, C, *spatial)``.

        Returns:
            ``(batch, C)`` if ``is_vec`` is set, else ``(batch, C, positions)``.
        """
        x = self._trilinearpool(x)
        if self.is_vec:
            # Average over spatial positions. Do not squeeze: that would drop
            # the batch axis when batch == 1.
            x = x.mean(dim=2)
        return x